diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..968f6d44 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,6 @@ +# .git-blame-ignore-revs +# Use this file to ignore commits in git blame that are just formatting changes +# Configure with: git config blame.ignoreRevsFile .git-blame-ignore-revs + +# Switch to ruff formatting +fba8bda \ No newline at end of file diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 6a3c3c48..c70c9d9b 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -49,7 +49,7 @@ python3.11 -m pytest --cov=src --cov-report term --cov-report html --cov-report ### Formatting code ```bash -yapf -i -r -p . +python3.11 -m ruff format . ``` ### Running lints & type checking @@ -58,7 +58,7 @@ yapf -i -r -p . # Type checking python3.11 -m mypy . # Linting -python3.11 -m pylint $(git ls-files '*.py') +python3.11 -m ruff check . ``` ### Generating Docs diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5e0cfacc..f8a5599f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -48,10 +48,10 @@ jobs: source venv/bin/activate pip3 install --upgrade pip python3.10 -m pip install -e ".[dev]" - - name: Lint with pylint + - name: Lint with ruff run: | source venv/bin/activate - python3.10 -m pylint $(git ls-files '*.py') + python3.10 -m ruff check . - name: Lint with mypy run: | source venv/bin/activate @@ -96,7 +96,7 @@ jobs: source venv/bin/activate pip3 install --upgrade pip python3.10 -m pip install -e ".[dev]" - - name: Check Formatting + - name: Check Formatting with ruff run: | source venv/bin/activate - yapf -d -r -p . + python3.10 -m ruff format --check . diff --git a/.gitignore b/.gitignore index 6339aeb6..8dc880a4 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,8 @@ doc/dist .idea .vscode/* !.vscode/settings.json + +# Local development files +async.md +modernization.md +uv.lock diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 0b758b1a..00000000 --- a/.pylintrc +++ /dev/null @@ -1,431 +0,0 @@ -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MASTER] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=third_party, setup.py - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). 
See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - duplicate-code, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat-in-sequence, - import-error, - import-self, - import-star-module-level, - inconsistent-return-statements, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-else-break, - no-else-continue, - no-else-raise, - no-else-return, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - too-few-public-methods, - too-many-ancestors, - too-many-arguments, - too-many-boolean-expressions, - too-many-branches, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-public-methods, - too-many-return-statements, - too-many-statements, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - useless-else-on-loop, - useless-object-inheritance, - useless-suppression, - using-cmp-argument, - wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, - import-outside-toplevel, - protected-access, - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. 
This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. 
-ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=100 - -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub, - TERMIOS, - Bastion, - rexec, - sets - -# Create a graph of every (i.e. 
internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=builtins.StandardError, - builtins.Exception, - builtins.BaseException diff --git a/.vscode/settings.json b/.vscode/settings.json index 22791a04..cb13601f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,15 +1,14 @@ { - "python.linting.enabled": true, - "python.linting.pylintEnabled": false, - "python.formatting.provider": "yapf", - "python.formatting.yapfArgs": [ - "--style", - "{based_on_style: google, indent_width: 4}" - ], - "python.linting.pylintPath": "pylint", + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "ruff.lint.enable": true, + "ruff.format.enable": true, "python.envFile": "${workspaceFolder}/venv", - "editor.formatOnSave": true, - "python.linting.lintOnSave": true, "python.linting.mypyEnabled": true, "mypy.dmypyExecutable": "${workspaceFolder}/venv/bin/dmypy", "files.autoSave": "afterDelay", diff --git a/docs/theme/devsite_translator/html.py b/docs/theme/devsite_translator/html.py index 1e2553cd..227e6781 100755 --- a/docs/theme/devsite_translator/html.py +++ b/docs/theme/devsite_translator/html.py @@ -16,16 +16,16 @@ from sphinx.writers import html _DESCTYPE_NAMES = { - 'class': 'Classes', - 'data': 'Constants', - 'function': 'Functions', - 'method': 'Methods', - 'attribute': 'Attributes', - 'exception': 'Exceptions' + "class": "Classes", + "data": "Constants", + "function": "Functions", + "method": "Methods", + "attribute": "Attributes", + "exception": "Exceptions", } # Use the default translator for these node types -_RENDER_WITH_DEFAULT = ['method', 'staticmethod', 'attribute'] +_RENDER_WITH_DEFAULT = ["method", "staticmethod", "attribute"] class FiresiteHTMLTranslator(html.HTMLTranslator): @@ -39,83 +39,80 @@ class FiresiteHTMLTranslator(html.HTMLTranslator): def __init__(self, builder, *args, **kwds): html.HTMLTranslator.__init__(self, builder, *args, **kwds) - self.current_section = 'intro' + self.current_section = "intro" # 
This flag gets set to True at the start of a new 'section' tag, and then # back to False after the first object signature in the section is processed self.insert_header = False def visit_desc(self, node): - if node.parent.tagname == 'section': + if node.parent.tagname == "section": self.insert_header = True - if node['desctype'] != self.current_section: - self.body.append( - f"

{_DESCTYPE_NAMES[node['desctype']]}

") - self.current_section = node['desctype'] - if node['desctype'] in _RENDER_WITH_DEFAULT: + if node["desctype"] != self.current_section: + self.body.append(f"

{_DESCTYPE_NAMES[node['desctype']]}

") + self.current_section = node["desctype"] + if node["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.visit_desc(self, node) else: - self.body.append(self.starttag(node, 'table', - CLASS=node['objtype'])) + self.body.append(self.starttag(node, "table", CLASS=node["objtype"])) def depart_desc(self, node): - if node['desctype'] in _RENDER_WITH_DEFAULT: + if node["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.depart_desc(self, node) else: - self.body.append('\n\n') + self.body.append("\n\n") def visit_desc_signature(self, node): - if node.parent['desctype'] in _RENDER_WITH_DEFAULT: + if node.parent["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.visit_desc_signature(self, node) else: - self.body.append('') - self.body.append(self.starttag(node, 'th')) + self.body.append("") + self.body.append(self.starttag(node, "th")) if self.insert_header: - self.body.append( - f"

{node['fullname']}

") + self.body.append(f'

{node["fullname"]}

') self.insert_header = False def depart_desc_signature(self, node): - if node.parent['desctype'] in _RENDER_WITH_DEFAULT: + if node.parent["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.depart_desc_signature(self, node) else: - self.body.append('') + self.body.append("") def visit_desc_content(self, node): - if node.parent['desctype'] in _RENDER_WITH_DEFAULT: + if node.parent["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.visit_desc_content(self, node) else: - self.body.append('') - self.body.append(self.starttag(node, 'td')) + self.body.append("") + self.body.append(self.starttag(node, "td")) def depart_desc_content(self, node): - if node.parent['desctype'] in _RENDER_WITH_DEFAULT: + if node.parent["desctype"] in _RENDER_WITH_DEFAULT: html.HTMLTranslator.depart_desc_content(self, node) else: - self.body.append('') + self.body.append("") def visit_title(self, node): - if node.parent.tagname == 'section': + if node.parent.tagname == "section": self.body.append('

') else: html.HTMLTranslator.visit_title(self, node) def depart_title(self, node): - if node.parent.tagname == 'section': - self.body.append('

') + if node.parent.tagname == "section": + self.body.append("") else: html.HTMLTranslator.depart_title(self, node) def visit_note(self, node): - self.body.append(self.starttag(node, 'aside', CLASS='note')) + self.body.append(self.starttag(node, "aside", CLASS="note")) def depart_note(self, node): # pylint: disable=unused-argument - self.body.append('\n\n') + self.body.append("\n\n") def visit_warning(self, node): - self.body.append(self.starttag(node, 'aside', CLASS='caution')) + self.body.append(self.starttag(node, "aside", CLASS="caution")) def depart_warning(self, node): # pylint: disable=unused-argument - self.body.append('\n\n') + self.body.append("\n\n") diff --git a/example/functions/main.py b/example/functions/main.py index 5ac35f3a..cbe77dfc 100644 --- a/example/functions/main.py +++ b/example/functions/main.py @@ -1,9 +1,11 @@ """ Example Firebase Functions written in Python """ -from firebase_functions import https_fn, options, params, pubsub_fn + from firebase_admin import initialize_app +from firebase_functions import https_fn, options, params, pubsub_fn + initialize_app() options.set_global_options( @@ -31,7 +33,8 @@ def oncallexample(req: https_fn.CallableRequest): return "Hello from https on call function example" -@pubsub_fn.on_message_published(topic="hello",) -def onmessagepublishedexample( - event: pubsub_fn.CloudEvent[pubsub_fn.MessagePublishedData]) -> None: +@pubsub_fn.on_message_published( + topic="hello", +) +def onmessagepublishedexample(event: pubsub_fn.CloudEvent[pubsub_fn.MessagePublishedData]) -> None: print("Hello from pubsub event:", event) diff --git a/pyproject.toml b/pyproject.toml index f2171b20..b3ddc861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,35 @@ pythonpath = [ [tool.coverage.report] skip_empty = true -[tool.yapf] -based_on_style = "google" -indent_width = 4 -[tool.yapfignore] -ignore_patterns = [ - "venv", - "build", - "dist", + +[tool.ruff] +target-version = "py310" +line-length = 100 +indent-width = 4 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "PL", # pylint ] +ignore = [ + "PLR0913", # Too many arguments + "PLR0912", # Too many branches + "PLR0915", # Too many statements + "PLR2004", # Magic value used in comparison + "PLW0603", # Using the global statement + "PLC0415", # Import outside toplevel + "E501", # Line too long (handled by formatter) +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/samples/basic_alerts/functions/main.py b/samples/basic_alerts/functions/main.py index 49aaf88f..dd5ab38c 100644 --- a/samples/basic_alerts/functions/main.py +++ b/samples/basic_alerts/functions/main.py @@ -1,36 +1,33 @@ """Cloud function samples for Firebase Alerts.""" from firebase_functions import alerts_fn -from firebase_functions.alerts import app_distribution_fn -from firebase_functions.alerts import billing_fn -from firebase_functions.alerts import crashlytics_fn -from firebase_functions.alerts import performance_fn +from firebase_functions.alerts import ( + app_distribution_fn, + billing_fn, + crashlytics_fn, + performance_fn, +) -@alerts_fn.on_alert_published( - alert_type=alerts_fn.AlertType.BILLING_PLAN_UPDATE) +@alerts_fn.on_alert_published(alert_type=alerts_fn.AlertType.BILLING_PLAN_UPDATE) def onalertpublished( - alert: alerts_fn.AlertEvent[alerts_fn.FirebaseAlertData[ - 
billing_fn.PlanUpdatePayload]] + alert: alerts_fn.AlertEvent[alerts_fn.FirebaseAlertData[billing_fn.PlanUpdatePayload]], ) -> None: print(alert) @app_distribution_fn.on_in_app_feedback_published() -def appdistributioninappfeedback( - alert: app_distribution_fn.InAppFeedbackEvent) -> None: +def appdistributioninappfeedback(alert: app_distribution_fn.InAppFeedbackEvent) -> None: print(alert) @app_distribution_fn.on_new_tester_ios_device_published() -def appdistributionnewrelease( - alert: app_distribution_fn.NewTesterDeviceEvent) -> None: +def appdistributionnewrelease(alert: app_distribution_fn.NewTesterDeviceEvent) -> None: print(alert) @billing_fn.on_plan_automated_update_published() -def billingautomatedplanupdate( - alert: billing_fn.BillingPlanAutomatedUpdateEvent) -> None: +def billingautomatedplanupdate(alert: billing_fn.BillingPlanAutomatedUpdateEvent) -> None: print(alert) @@ -40,42 +37,35 @@ def billingplanupdate(alert: billing_fn.BillingPlanUpdateEvent) -> None: @crashlytics_fn.on_new_fatal_issue_published() -def crashlyticsnewfatalissue( - alert: crashlytics_fn.CrashlyticsNewFatalIssueEvent) -> None: +def crashlyticsnewfatalissue(alert: crashlytics_fn.CrashlyticsNewFatalIssueEvent) -> None: print(alert) @crashlytics_fn.on_new_nonfatal_issue_published() -def crashlyticsnewnonfatalissue( - alert: crashlytics_fn.CrashlyticsNewNonfatalIssueEvent) -> None: +def crashlyticsnewnonfatalissue(alert: crashlytics_fn.CrashlyticsNewNonfatalIssueEvent) -> None: print(alert) @crashlytics_fn.on_new_anr_issue_published() -def crashlyticsnewanrissue( - alert: crashlytics_fn.CrashlyticsNewAnrIssueEvent) -> None: +def crashlyticsnewanrissue(alert: crashlytics_fn.CrashlyticsNewAnrIssueEvent) -> None: print(alert) @crashlytics_fn.on_regression_alert_published() -def crashlyticsregression( - alert: crashlytics_fn.CrashlyticsRegressionAlertEvent) -> None: +def crashlyticsregression(alert: crashlytics_fn.CrashlyticsRegressionAlertEvent) -> None: print(alert) @crashlytics_fn.on_stability_digest_published() -def crashlyticsstabilitydigest( - alert: crashlytics_fn.CrashlyticsStabilityDigestEvent) -> None: +def crashlyticsstabilitydigest(alert: crashlytics_fn.CrashlyticsStabilityDigestEvent) -> None: print(alert) @crashlytics_fn.on_velocity_alert_published() -def crashlyticsvelocity( - alert: crashlytics_fn.CrashlyticsVelocityAlertEvent) -> None: +def crashlyticsvelocity(alert: crashlytics_fn.CrashlyticsVelocityAlertEvent) -> None: print(alert) @performance_fn.on_threshold_alert_published() -def performancethreshold( - alert: performance_fn.PerformanceThresholdAlertEvent) -> None: +def performancethreshold(alert: performance_fn.PerformanceThresholdAlertEvent) -> None: print(alert) diff --git a/samples/basic_db/functions/main.py b/samples/basic_db/functions/main.py index 7fd134e0..d28d2914 100644 --- a/samples/basic_db/functions/main.py +++ b/samples/basic_db/functions/main.py @@ -1,9 +1,11 @@ """ Example Firebase Functions for RTDB written in Python """ -from firebase_functions import db_fn, options + from firebase_admin import initialize_app +from firebase_functions import db_fn, options + initialize_app() options.set_global_options(region=options.SupportedRegion.EUROPE_WEST1) diff --git a/samples/basic_eventarc/functions/main.py b/samples/basic_eventarc/functions/main.py index 6f716fa0..a471b485 100644 --- a/samples/basic_eventarc/functions/main.py +++ b/samples/basic_eventarc/functions/main.py @@ -1,9 +1,11 @@ """Firebase Cloud Functions for Eventarc triggers example.""" + from firebase_functions import 
eventarc_fn @eventarc_fn.on_custom_event_published( - event_type="firebase.extensions.storage-resize-images.v1.complete",) + event_type="firebase.extensions.storage-resize-images.v1.complete", +) def onimageresize(event: eventarc_fn.CloudEvent) -> None: """ Handle image resize events from the Firebase Storage Resize Images extension. diff --git a/samples/basic_firestore/functions/main.py b/samples/basic_firestore/functions/main.py index 7703ee65..778bb0df 100644 --- a/samples/basic_firestore/functions/main.py +++ b/samples/basic_firestore/functions/main.py @@ -1,17 +1,18 @@ """ Example Firebase Functions for Firestore written in Python """ -from firebase_functions import firestore_fn, options + from firebase_admin import initialize_app +from firebase_functions import firestore_fn, options + initialize_app() options.set_global_options(region=options.SupportedRegion.EUROPE_WEST1) @firestore_fn.on_document_written(document="hello/{world}") -def onfirestoredocumentwritten( - event: firestore_fn.Event[firestore_fn.Change]) -> None: +def onfirestoredocumentwritten(event: firestore_fn.Event[firestore_fn.Change]) -> None: print("Hello from Firestore document write event:", event) @@ -26,6 +27,5 @@ def onfirestoredocumentdeleted(event: firestore_fn.Event) -> None: @firestore_fn.on_document_updated(document="hello/world") -def onfirestoredocumentupdated( - event: firestore_fn.Event[firestore_fn.Change]) -> None: +def onfirestoredocumentupdated(event: firestore_fn.Event[firestore_fn.Change]) -> None: print("Hello from Firestore document updated event:", event) diff --git a/samples/basic_params/functions/main.py b/samples/basic_params/functions/main.py index 731e77b3..89dbf41e 100644 --- a/samples/basic_params/functions/main.py +++ b/samples/basic_params/functions/main.py @@ -1,9 +1,11 @@ """ Example Function params & inputs. 
""" -from firebase_functions import storage_fn, params + from firebase_admin import initialize_app +from firebase_functions import params, storage_fn + initialize_app() bucket = params.StringParam( @@ -17,13 +19,11 @@ output_path = params.StringParam( "OUTPUT_PATH", label="storage bucket output path", - description= - "The path of in the bucket where processed images will be stored.", + description="The path of in the bucket where processed images will be stored.", input=params.TextInput( example="/images/processed", validation_regex=r"^\/.*$", - validation_error_message= - "Must be a valid path starting with a forward slash", + validation_error_message="Must be a valid path starting with a forward slash", ), default="/images/processed", ) @@ -32,23 +32,26 @@ "IMAGE_TYPE", label="convert image to preferred types", description="The image types you'd like your source image to convert to.", - input=params.MultiSelectInput([ - params.SelectOption(value="jpeg", label="jpeg"), - params.SelectOption(value="png", label="png"), - params.SelectOption(value="webp", label="webp"), - ]), + input=params.MultiSelectInput( + [ + params.SelectOption(value="jpeg", label="jpeg"), + params.SelectOption(value="png", label="png"), + params.SelectOption(value="webp", label="webp"), + ] + ), default=["jpeg", "png"], ) delete_original = params.BoolParam( "DELETE_ORIGINAL_FILE", label="delete the original file", - description= - "Do you want to automatically delete the original file from the Cloud Storage?", - input=params.SelectInput([ - params.SelectOption(value=True, label="Delete on any resize attempt"), - params.SelectOption(value=False, label="Don't delete"), - ],), + description="Do you want to automatically delete the original file from the Cloud Storage?", + input=params.SelectInput( + [ + params.SelectOption(value=True, label="Delete on any resize attempt"), + params.SelectOption(value=False, label="Don't delete"), + ], + ), default=True, ) diff --git a/samples/basic_storage/functions/main.py b/samples/basic_storage/functions/main.py index f8be7b4e..b5e6e9cd 100644 --- a/samples/basic_storage/functions/main.py +++ b/samples/basic_storage/functions/main.py @@ -2,10 +2,11 @@ Example Firebase Functions for Storage triggers. 
""" -from firebase_functions import storage_fn -from firebase_functions.storage_fn import StorageObjectData, CloudEvent from firebase_admin import initialize_app +from firebase_functions import storage_fn +from firebase_functions.storage_fn import CloudEvent, StorageObjectData + initialize_app() diff --git a/samples/basic_tasks/functions/main.py b/samples/basic_tasks/functions/main.py index da5f922b..b314e561 100644 --- a/samples/basic_tasks/functions/main.py +++ b/samples/basic_tasks/functions/main.py @@ -5,8 +5,9 @@ from firebase_admin import initialize_app from google.cloud import tasks_v2 -from firebase_functions import tasks_fn, https_fn -from firebase_functions.options import SupportedRegion, RetryConfig, RateLimits + +from firebase_functions import https_fn, tasks_fn +from firebase_functions.options import RateLimits, RetryConfig, SupportedRegion app = initialize_app() @@ -50,14 +51,12 @@ def enqueuetask(req: https_fn.Request) -> https_fn.Response: "http_request": { "http_method": tasks_v2.HttpMethod.POST, "url": url, - "headers": { - "Content-type": "application/json" - }, + "headers": {"Content-type": "application/json"}, "body": json.dumps(body).encode(), }, - "schedule_time": - datetime.datetime.utcnow() + datetime.timedelta(minutes=1), - }) + "schedule_time": datetime.datetime.utcnow() + datetime.timedelta(minutes=1), + } + ) parent = client.queue_path( app.project_id, diff --git a/samples/basic_test_lab/functions/main.py b/samples/basic_test_lab/functions/main.py index 49f766e2..b8c5cba5 100644 --- a/samples/basic_test_lab/functions/main.py +++ b/samples/basic_test_lab/functions/main.py @@ -1,4 +1,5 @@ """Firebase Cloud Functions for Test Lab.""" + from firebase_functions.test_lab_fn import ( CloudEvent, TestMatrixCompletedData, @@ -13,14 +14,10 @@ def testmatrixcompleted(event: CloudEvent[TestMatrixCompletedData]) -> None: print(f"Test Matrix Outcome Summary: {event.data.outcome_summary}") print("Result Storage:") - print( - f" Tool Results History: {event.data.result_storage.tool_results_history}" - ) + print(f" Tool Results History: {event.data.result_storage.tool_results_history}") print(f" Results URI: {event.data.result_storage.results_uri}") print(f" GCS Path: {event.data.result_storage.gcs_path}") - print( - f" Tool Results Execution: {event.data.result_storage.tool_results_execution}" - ) + print(f" Tool Results Execution: {event.data.result_storage.tool_results_execution}") print("Client Info:") print(f" Client: {event.data.client_info.client}") diff --git a/samples/identity/functions/main.py b/samples/identity/functions/main.py index 9fa44f15..409df247 100644 --- a/samples/identity/functions/main.py +++ b/samples/identity/functions/main.py @@ -1,4 +1,5 @@ """Firebase Cloud Functions for blocking auth functions example.""" + from firebase_functions import identity_fn @@ -8,15 +9,19 @@ refresh_token=True, ) def beforeusercreated( - event: identity_fn.AuthBlockingEvent + event: identity_fn.AuthBlockingEvent, ) -> identity_fn.BeforeCreateResponse | None: print(event) if not event.data.email: return None if "@cats.com" in event.data.email: - return identity_fn.BeforeCreateResponse(display_name="Meow!",) + return identity_fn.BeforeCreateResponse( + display_name="Meow!", + ) if "@dogs.com" in event.data.email: - return identity_fn.BeforeCreateResponse(display_name="Woof!",) + return identity_fn.BeforeCreateResponse( + display_name="Woof!", + ) return None @@ -26,7 +31,7 @@ def beforeusercreated( refresh_token=True, ) def beforeusersignedin( - event: 
identity_fn.AuthBlockingEvent + event: identity_fn.AuthBlockingEvent, ) -> identity_fn.BeforeSignInResponse | None: print(event) if not event.data.email: diff --git a/setup.py b/setup.py index efbc8e0b..901c3898 100644 --- a/setup.py +++ b/setup.py @@ -14,56 +14,70 @@ """ Setup for Firebase Functions Python. """ + from os import path + from setuptools import find_packages, setup install_requires = [ - 'flask>=2.1.2', 'functions-framework>=3.0.0', 'firebase-admin>=6.0.0', - 'pyyaml>=6.0', 'typing-extensions>=4.4.0', 'cloudevents>=1.2.0,<2.0.0', - 'flask-cors>=3.0.10', 'pyjwt[crypto]>=2.5.0', 'google-events==0.5.0', - 'google-cloud-firestore>=2.11.0' + "flask>=2.1.2", + "functions-framework>=3.0.0", + "firebase-admin>=6.0.0", + "pyyaml>=6.0", + "typing-extensions>=4.4.0", + "cloudevents>=1.2.0,<2.0.0", + "flask-cors>=3.0.10", + "pyjwt[crypto]>=2.5.0", + "google-events==0.5.0", + "google-cloud-firestore>=2.11.0", ] dev_requires = [ - 'pytest>=7.1.2', 'setuptools>=63.4.2', 'pylint>=2.16.1', - 'pytest-cov>=3.0.0', 'mypy>=1.0.0', 'sphinx>=6.1.3', - 'sphinxcontrib-napoleon>=0.7', 'yapf>=0.32.0', 'toml>=0.10.2', - 'google-cloud-tasks>=2.13.1' + "pytest>=7.1.2", + "setuptools>=63.4.2", + "pytest-cov>=3.0.0", + "mypy>=1.0.0", + "sphinx>=6.1.3", + "sphinxcontrib-napoleon>=0.7", + "toml>=0.10.2", + "google-cloud-tasks>=2.13.1", + "ruff>=0.1.0", ] # Read in the package metadata per recommendations from: # https://packaging.python.org/guides/single-sourcing-package-version/ -init_path = path.join(path.dirname(path.abspath(__file__)), 'src', - 'firebase_functions', '__init__.py') +init_path = path.join( + path.dirname(path.abspath(__file__)), "src", "firebase_functions", "__init__.py" +) version = {} with open(init_path) as fp: exec(fp.read(), version) # pylint: disable=exec-used long_description = ( - 'The Firebase Functions Python SDK provides an SDK for defining' - ' Cloud Functions for Firebase.') + "The Firebase Functions Python SDK provides an SDK for defining Cloud Functions for Firebase." 
+) setup( - name='firebase_functions', - version=version['__version__'], - description='Firebase Functions Python SDK', + name="firebase_functions", + version=version["__version__"], + description="Firebase Functions Python SDK", long_description=long_description, - url='https://github.com/firebase/firebase-functions-python', - author='Firebase Team', - keywords=['firebase', 'functions', 'google', 'cloud'], - license='Apache License 2.0', + url="https://github.com/firebase/firebase-functions-python", + author="Firebase Team", + keywords=["firebase", "functions", "google", "cloud"], + license="Apache License 2.0", install_requires=install_requires, - extras_require={'dev': dev_requires}, - packages=find_packages(where='src'), - package_dir={'': 'src'}, + extras_require={"dev": dev_requires}, + packages=find_packages(where="src"), + package_dir={"": "src"}, include_package_data=True, - package_data={'firebase_functions': ['py.typed']}, - python_requires='>=3.10', + package_data={"firebase_functions": ["py.typed"]}, + python_requires=">=3.10", classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Build Tools', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Topic :: Software Development :: Build Tools", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], ) diff --git a/src/firebase_functions/alerts/app_distribution_fn.py b/src/firebase_functions/alerts/app_distribution_fn.py index 5ff3a394..d19f36f8 100644 --- a/src/firebase_functions/alerts/app_distribution_fn.py +++ b/src/firebase_functions/alerts/app_distribution_fn.py @@ -15,15 +15,16 @@ """ Cloud functions to handle Firebase App Distribution events from Firebase Alerts. """ + import dataclasses as _dataclasses import functools as _functools import typing as _typing + import cloudevents.http as _ce -from firebase_functions.alerts import FirebaseAlertData import firebase_functions.private.util as _util - -from firebase_functions.core import T, CloudEvent +from firebase_functions.alerts import FirebaseAlertData +from firebase_functions.core import CloudEvent, T from firebase_functions.options import AppDistributionOptions @@ -64,7 +65,7 @@ class InAppFeedbackPayload: feedback_report: str """ - Resource name. Format: + Resource name. Format: `projects/{project_number}/apps/{app_id}/releases/{release_id}/feedbackReports/{feedback_id}` """ @@ -127,8 +128,7 @@ class AppDistributionEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_in_app_feedback_published' functions. """ -OnNewTesterIosDevicePublishedCallable = _typing.Callable[[NewTesterDeviceEvent], - None] +OnNewTesterIosDevicePublishedCallable = _typing.Callable[[NewTesterDeviceEvent], None] """ The type of the callable for 'on_new_tester_ios_device_published' functions. """ @@ -141,9 +141,10 @@ class AppDistributionEvent(CloudEvent[FirebaseAlertData[T]]): @_util.copy_func_kwargs(AppDistributionOptions) def on_new_tester_ios_device_published( - **kwargs -) -> _typing.Callable[[OnNewTesterIosDevicePublishedCallable], - OnNewTesterIosDevicePublishedCallable]: + **kwargs, +) -> _typing.Callable[ + [OnNewTesterIosDevicePublishedCallable], OnNewTesterIosDevicePublishedCallable +]: """ Event handler which runs every time a new tester iOS device is added. @@ -151,7 +152,7 @@ def on_new_tester_ios_device_published( .. 
code-block:: python - import firebase_functions.alerts.app_distribution_fn as app_distribution_fn + import firebase_functions.alerts.app_distribution_fn as app_distribution_fn @app_distribution_fn.on_new_tester_ios_device_published() def example(alert: app_distribution_fn.NewTesterDeviceEvent) -> None: @@ -160,27 +161,28 @@ def example(alert: app_distribution_fn.NewTesterDeviceEvent) -> None: :param \\*\\*kwargs: Options. :type \\*\\*kwargs: as :exc:`firebase_functions.options.AppDistributionOptions` :rtype: :exc:`typing.Callable` - \\[ - \\[ :exc:`firebase_functions.alerts.app_distribution_fn.NewTesterDeviceEvent` \\], - `None` + \\[ + \\[ :exc:`firebase_functions.alerts.app_distribution_fn.NewTesterDeviceEvent` \\], + `None` \\] A function that takes a NewTesterDeviceEvent and returns None. """ options = AppDistributionOptions(**kwargs) def on_new_tester_ios_device_published_inner_decorator( - func: OnNewTesterIosDevicePublishedCallable): - + func: OnNewTesterIosDevicePublishedCallable, + ): @_functools.wraps(func) def on_new_tester_ios_device_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import app_distribution_event_from_ce + func(app_distribution_event_from_ce(raw)) _util.set_func_endpoint_attr( on_new_tester_ios_device_published_wrapped, options._endpoint( func_name=func.__name__, - alert_type='appDistribution.newTesterIosDevice', + alert_type="appDistribution.newTesterIosDevice", ), ) return on_new_tester_ios_device_published_wrapped @@ -190,9 +192,8 @@ def on_new_tester_ios_device_published_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(AppDistributionOptions) def on_in_app_feedback_published( - **kwargs -) -> _typing.Callable[[OnInAppFeedbackPublishedCallable], - OnInAppFeedbackPublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnInAppFeedbackPublishedCallable], OnInAppFeedbackPublishedCallable]: """ Event handler which runs every time new feedback is received. @@ -200,7 +201,7 @@ def on_in_app_feedback_published( .. code-block:: python - import firebase_functions.alerts.app_distribution_fn as app_distribution_fn + import firebase_functions.alerts.app_distribution_fn as app_distribution_fn @app_distribution_fn.on_in_app_feedback_published() def example(alert: app_distribution_fn.InAppFeedbackEvent) -> None: @@ -209,27 +210,26 @@ def example(alert: app_distribution_fn.InAppFeedbackEvent) -> None: :param \\*\\*kwargs: Options. :type \\*\\*kwargs: as :exc:`firebase_functions.options.AppDistributionOptions` :rtype: :exc:`typing.Callable` - \\[ - \\[ :exc:`firebase_functions.alerts.app_distribution_fn.InAppFeedbackEvent` \\], - `None` + \\[ + \\[ :exc:`firebase_functions.alerts.app_distribution_fn.InAppFeedbackEvent` \\], + `None` \\] A function that takes a NewTesterDeviceEvent and returns None. 
""" options = AppDistributionOptions(**kwargs) - def on_in_app_feedback_published_inner_decorator( - func: OnInAppFeedbackPublishedCallable): - + def on_in_app_feedback_published_inner_decorator(func: OnInAppFeedbackPublishedCallable): @_functools.wraps(func) def on_in_app_feedback_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import app_distribution_event_from_ce + func(app_distribution_event_from_ce(raw)) _util.set_func_endpoint_attr( on_in_app_feedback_published_wrapped, options._endpoint( func_name=func.__name__, - alert_type='appDistribution.inAppFeedback', + alert_type="appDistribution.inAppFeedback", ), ) return on_in_app_feedback_published_wrapped diff --git a/src/firebase_functions/alerts/billing_fn.py b/src/firebase_functions/alerts/billing_fn.py index 6c698ed5..6411e5f0 100644 --- a/src/firebase_functions/alerts/billing_fn.py +++ b/src/firebase_functions/alerts/billing_fn.py @@ -15,15 +15,16 @@ """ Cloud functions to handle billing events from Firebase Alerts. """ + import dataclasses as _dataclasses import functools as _functools import typing as _typing + import cloudevents.http as _ce -from firebase_functions.alerts import FirebaseAlertData import firebase_functions.private.util as _util - -from firebase_functions.core import T, CloudEvent +from firebase_functions.alerts import FirebaseAlertData +from firebase_functions.core import CloudEvent, T from firebase_functions.options import BillingOptions @@ -85,8 +86,7 @@ class BillingEvent(CloudEvent[FirebaseAlertData[T]]): The type of the callable for 'on_plan_update_published' functions. """ -OnPlanAutomatedUpdatePublishedCallable = _typing.Callable[ - [BillingPlanAutomatedUpdateEvent], None] +OnPlanAutomatedUpdatePublishedCallable = _typing.Callable[[BillingPlanAutomatedUpdateEvent], None] """ The type of the callable for 'on_plan_automated_update_published' functions. """ @@ -94,9 +94,8 @@ class BillingEvent(CloudEvent[FirebaseAlertData[T]]): @_util.copy_func_kwargs(BillingOptions) def on_plan_update_published( - **kwargs -) -> _typing.Callable[[OnPlanUpdatePublishedCallable], - OnPlanUpdatePublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnPlanUpdatePublishedCallable], OnPlanUpdatePublishedCallable]: """ Event handler which triggers when a Firebase Alerts billing event is published. @@ -104,7 +103,7 @@ def on_plan_update_published( .. code-block:: python - import firebase_functions.alerts.billing_fn as billing_fn + import firebase_functions.alerts.billing_fn as billing_fn @billing_fn.on_plan_update_published() def example(alert: billing_fn.BillingPlanUpdateEvent) -> None: @@ -113,27 +112,26 @@ def example(alert: billing_fn.BillingPlanUpdateEvent) -> None: :param \\*\\*kwargs: Options. :type \\*\\*kwargs: as :exc:`firebase_functions.options.BillingOptions` :rtype: :exc:`typing.Callable` - \\[ - \\[ :exc:`firebase_functions.alerts.billing_fn.BillingPlanUpdateEvent` \\], - `None` + \\[ + \\[ :exc:`firebase_functions.alerts.billing_fn.BillingPlanUpdateEvent` \\], + `None` \\] A function that takes a BillingPlanUpdateEvent and returns None. 
""" options = BillingOptions(**kwargs) - def on_plan_update_published_inner_decorator( - func: OnPlanUpdatePublishedCallable): - + def on_plan_update_published_inner_decorator(func: OnPlanUpdatePublishedCallable): @_functools.wraps(func) def on_plan_update_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import billing_event_from_ce + func(billing_event_from_ce(raw)) _util.set_func_endpoint_attr( on_plan_update_published_wrapped, options._endpoint( func_name=func.__name__, - alert_type='billing.planUpdate', + alert_type="billing.planUpdate", ), ) return on_plan_update_published_wrapped @@ -143,9 +141,10 @@ def on_plan_update_published_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(BillingOptions) def on_plan_automated_update_published( - **kwargs -) -> _typing.Callable[[OnPlanAutomatedUpdatePublishedCallable], - OnPlanAutomatedUpdatePublishedCallable]: + **kwargs, +) -> _typing.Callable[ + [OnPlanAutomatedUpdatePublishedCallable], OnPlanAutomatedUpdatePublishedCallable +]: """ Event handler which triggers when a Firebase Alerts billing event is published. @@ -153,7 +152,7 @@ def on_plan_automated_update_published( .. code-block:: python - import firebase_functions.alerts.billing_fn as billing_fn + import firebase_functions.alerts.billing_fn as billing_fn @billing_fn.on_plan_automated_update_published() def example(alert: billing_fn.BillingPlanAutomatedUpdateEvent) -> None: @@ -162,27 +161,28 @@ def example(alert: billing_fn.BillingPlanAutomatedUpdateEvent) -> None: :param \\*\\*kwargs: Options. :type \\*\\*kwargs: as :exc:`firebase_functions.options.BillingOptions` :rtype: :exc:`typing.Callable` - \\[ - \\[ :exc:`firebase_functions.alerts.billing_fn.BillingPlanAutomatedUpdateEvent` \\], - `None` + \\[ + \\[ :exc:`firebase_functions.alerts.billing_fn.BillingPlanAutomatedUpdateEvent` \\], + `None` \\] A function that takes a BillingPlanUpdateEvent and returns None. """ options = BillingOptions(**kwargs) def on_plan_automated_update_published_inner_decorator( - func: OnPlanAutomatedUpdatePublishedCallable): - + func: OnPlanAutomatedUpdatePublishedCallable, + ): @_functools.wraps(func) def on_plan_automated_update_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import billing_event_from_ce + func(billing_event_from_ce(raw)) _util.set_func_endpoint_attr( on_plan_automated_update_published_wrapped, options._endpoint( func_name=func.__name__, - alert_type='billing.planAutomatedUpdate', + alert_type="billing.planAutomatedUpdate", ), ) return on_plan_automated_update_published_wrapped diff --git a/src/firebase_functions/alerts/crashlytics_fn.py b/src/firebase_functions/alerts/crashlytics_fn.py index 915c4541..1d3d25f2 100644 --- a/src/firebase_functions/alerts/crashlytics_fn.py +++ b/src/firebase_functions/alerts/crashlytics_fn.py @@ -15,15 +15,18 @@ """ Cloud functions to handle Crashlytics events from Firebase Alerts. 
""" + import dataclasses as _dataclasses -import typing as _typing -import cloudevents.http as _ce import datetime as _dt import functools as _functools +import typing as _typing + +import cloudevents.http as _ce + +import firebase_functions.private.util as _util from firebase_functions.alerts import FirebaseAlertData -from firebase_functions.core import T, CloudEvent +from firebase_functions.core import CloudEvent, T from firebase_functions.options import CrashlyticsOptions -import firebase_functions.private.util as _util @_dataclasses.dataclass(frozen=True) @@ -220,8 +223,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_new_fatal_issue_published' functions. """ -OnNewFatalIssuePublishedCallable = _typing.Callable[ - [CrashlyticsNewFatalIssueEvent], None] +OnNewFatalIssuePublishedCallable = _typing.Callable[[CrashlyticsNewFatalIssueEvent], None] """ The type of the callable for 'on_new_fatal_issue_published' functions. """ @@ -231,8 +233,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_new_nonfatal_issue_published' functions. """ -OnNewNonfatalIssuePublishedCallable = _typing.Callable[ - [CrashlyticsNewNonfatalIssueEvent], None] +OnNewNonfatalIssuePublishedCallable = _typing.Callable[[CrashlyticsNewNonfatalIssueEvent], None] """ The type of the callable for 'on_new_nonfatal_issue_published' functions. """ @@ -242,8 +243,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_regression_alert_published' functions. """ -OnRegressionAlertPublishedCallable = _typing.Callable[ - [CrashlyticsRegressionAlertEvent], None] +OnRegressionAlertPublishedCallable = _typing.Callable[[CrashlyticsRegressionAlertEvent], None] """ The type of the callable for 'on_regression_alert_published' functions. """ @@ -253,8 +253,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_stability_digest_published' functions. """ -OnStabilityDigestPublishedCallable = _typing.Callable[ - [CrashlyticsStabilityDigestEvent], None] +OnStabilityDigestPublishedCallable = _typing.Callable[[CrashlyticsStabilityDigestEvent], None] """ The type of the callable for 'on_stability_digest_published' functions. """ @@ -264,8 +263,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_velocity_alert_published' functions. """ -OnVelocityAlertPublishedCallable = _typing.Callable[ - [CrashlyticsVelocityAlertEvent], None] +OnVelocityAlertPublishedCallable = _typing.Callable[[CrashlyticsVelocityAlertEvent], None] """ The type of the callable for 'on_velocity_alert_published' functions. """ @@ -275,8 +273,7 @@ class CrashlyticsEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_new_anr_issue_published' functions. """ -OnNewAnrIssuePublishedCallable = _typing.Callable[[CrashlyticsNewAnrIssueEvent], - None] +OnNewAnrIssuePublishedCallable = _typing.Callable[[CrashlyticsNewAnrIssueEvent], None] """ The type of the callable for 'on_new_anr_issue_published' functions. 
""" @@ -289,10 +286,10 @@ def _create_crashlytics_decorator( options = CrashlyticsOptions(**kwargs) def crashlytics_decorator_inner(func: _typing.Callable): - @_functools.wraps(func) def crashlytics_decorator_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import crashlytics_event_from_ce + func(crashlytics_event_from_ce(raw)) _util.set_func_endpoint_attr( @@ -309,9 +306,8 @@ def crashlytics_decorator_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(CrashlyticsOptions) def on_new_fatal_issue_published( - **kwargs -) -> _typing.Callable[[OnNewFatalIssuePublishedCallable], - OnNewFatalIssuePublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnNewFatalIssuePublishedCallable], OnNewFatalIssuePublishedCallable]: """ Event handler which runs every time a new fatal issue is received. @@ -334,14 +330,13 @@ def example(alert: crashlytics_fn.CrashlyticsNewFatalIssueEvent) -> None: \\] A function that takes a CrashlyticsNewFatalIssueEvent and returns None. """ - return _create_crashlytics_decorator('crashlytics.newFatalIssue', **kwargs) + return _create_crashlytics_decorator("crashlytics.newFatalIssue", **kwargs) @_util.copy_func_kwargs(CrashlyticsOptions) def on_new_nonfatal_issue_published( - **kwargs -) -> _typing.Callable[[OnNewNonfatalIssuePublishedCallable], - OnNewNonfatalIssuePublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnNewNonfatalIssuePublishedCallable], OnNewNonfatalIssuePublishedCallable]: """ Event handler which runs every time a new nonfatal issue is received. @@ -364,15 +359,13 @@ def example(alert: crashlytics_fn.CrashlyticsNewNonfatalIssueEvent) -> None: \\] A function that takes a CrashlyticsNewNonfatalIssueEvent and returns None. """ - return _create_crashlytics_decorator('crashlytics.newNonfatalIssue', - **kwargs) + return _create_crashlytics_decorator("crashlytics.newNonfatalIssue", **kwargs) @_util.copy_func_kwargs(CrashlyticsOptions) def on_regression_alert_published( - **kwargs -) -> _typing.Callable[[OnRegressionAlertPublishedCallable], - OnRegressionAlertPublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnRegressionAlertPublishedCallable], OnRegressionAlertPublishedCallable]: """ Event handler which runs every time a regression alert is received. @@ -395,14 +388,13 @@ def example(alert: crashlytics_fn.CrashlyticsRegressionAlertEvent) -> None: \\] A function that takes a CrashlyticsRegressionAlertEvent and returns None. """ - return _create_crashlytics_decorator('crashlytics.regression', **kwargs) + return _create_crashlytics_decorator("crashlytics.regression", **kwargs) @_util.copy_func_kwargs(CrashlyticsOptions) def on_stability_digest_published( - **kwargs -) -> _typing.Callable[[OnStabilityDigestPublishedCallable], - OnStabilityDigestPublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnStabilityDigestPublishedCallable], OnStabilityDigestPublishedCallable]: """ Event handler which runs every time a stability digest is received. @@ -425,15 +417,13 @@ def example(alert: crashlytics_fn.CrashlyticsStabilityDigestEvent) -> None: \\] A function that takes a CrashlyticsStabilityDigestEvent and returns None. 
""" - return _create_crashlytics_decorator('crashlytics.stabilityDigest', - **kwargs) + return _create_crashlytics_decorator("crashlytics.stabilityDigest", **kwargs) @_util.copy_func_kwargs(CrashlyticsOptions) def on_velocity_alert_published( - **kwargs -) -> _typing.Callable[[OnVelocityAlertPublishedCallable], - OnVelocityAlertPublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnVelocityAlertPublishedCallable], OnVelocityAlertPublishedCallable]: """ Event handler which runs every time a velocity alert is received. @@ -456,14 +446,13 @@ def example(alert: crashlytics_fn.CrashlyticsVelocityAlertEvent) -> None: \\] A function that takes a CrashlyticsVelocityAlertEvent and returns None. """ - return _create_crashlytics_decorator('crashlytics.velocity', **kwargs) + return _create_crashlytics_decorator("crashlytics.velocity", **kwargs) @_util.copy_func_kwargs(CrashlyticsOptions) def on_new_anr_issue_published( - **kwargs -) -> _typing.Callable[[OnNewAnrIssuePublishedCallable], - OnNewAnrIssuePublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnNewAnrIssuePublishedCallable], OnNewAnrIssuePublishedCallable]: """ Event handler which runs every time a new ANR issue is received. @@ -486,4 +475,4 @@ def example(alert: crashlytics_fn.CrashlyticsNewAnrIssueEvent) -> None: \\] A function that takes a CrashlyticsNewAnrIssueEvent and returns None. """ - return _create_crashlytics_decorator('crashlytics.newAnrIssue', **kwargs) + return _create_crashlytics_decorator("crashlytics.newAnrIssue", **kwargs) diff --git a/src/firebase_functions/alerts/performance_fn.py b/src/firebase_functions/alerts/performance_fn.py index ad00c6ab..6f6759ed 100644 --- a/src/firebase_functions/alerts/performance_fn.py +++ b/src/firebase_functions/alerts/performance_fn.py @@ -19,12 +19,12 @@ import dataclasses as _dataclasses import functools as _functools import typing as _typing + import cloudevents.http as _ce -from firebase_functions.alerts import FirebaseAlertData import firebase_functions.private.util as _util - -from firebase_functions.core import T, CloudEvent +from firebase_functions.alerts import FirebaseAlertData +from firebase_functions.core import CloudEvent, T from firebase_functions.options import PerformanceOptions @@ -37,19 +37,19 @@ class ThresholdAlertPayload: event_name: str """ - Name of the trace or network request this alert is for + Name of the trace or network request this alert is for (e.g. my_custom_trace, firebase.com/api/123). """ event_type: str """ - The resource type this alert is for (i.e. trace, network request, + The resource type this alert is for (i.e. trace, network request, screen rendering, etc.). """ metric_type: str """ - The metric type this alert is for (i.e. success rate, + The metric type this alert is for (i.e. success rate, response time, duration, etc.). """ @@ -85,15 +85,15 @@ class ThresholdAlertPayload: condition_percentile: float | int | None = None """ - The percentile of the alert condition, can be 0 if percentile + The percentile of the alert condition, can be 0 if percentile is not applicable to the alert condition and omitted; range: [1, 100]. """ app_version: str | None = None """ - The app version this alert was triggered for, can be omitted - if the alert is for a network request (because the alert was + The app version this alert was triggered for, can be omitted + if the alert is for a network request (because the alert was checked against data from all versions of app) or a web app (where the app is versionless). 
""" @@ -121,8 +121,7 @@ class PerformanceEvent(CloudEvent[FirebaseAlertData[T]]): The type of the event for 'on_threshold_alert_published' functions. """ -OnThresholdAlertPublishedCallable = _typing.Callable[ - [PerformanceThresholdAlertEvent], None] +OnThresholdAlertPublishedCallable = _typing.Callable[[PerformanceThresholdAlertEvent], None] """ The type of the callable for 'on_threshold_alert_published' functions. """ @@ -130,9 +129,8 @@ class PerformanceEvent(CloudEvent[FirebaseAlertData[T]]): @_util.copy_func_kwargs(PerformanceOptions) def on_threshold_alert_published( - **kwargs -) -> _typing.Callable[[OnThresholdAlertPublishedCallable], - OnThresholdAlertPublishedCallable]: + **kwargs, +) -> _typing.Callable[[OnThresholdAlertPublishedCallable], OnThresholdAlertPublishedCallable]: """ Event handler which runs every time a threshold alert is received. @@ -140,7 +138,7 @@ def on_threshold_alert_published( .. code-block:: python - import firebase_functions.alerts.performance_fn as performance_fn + import firebase_functions.alerts.performance_fn as performance_fn @performance_fn.on_threshold_alert_published() def example(alert: performance_fn.PerformanceThresholdAlertEvent) -> None: @@ -149,27 +147,26 @@ def example(alert: performance_fn.PerformanceThresholdAlertEvent) -> None: :param \\*\\*kwargs: Options. :type \\*\\*kwargs: as :exc:`firebase_functions.options.PerformanceOptions` :rtype: :exc:`typing.Callable` - \\[ - \\[ :exc:`firebase_functions.alerts.performance_fn.PerformanceThresholdAlertEvent` \\], - `None` + \\[ + \\[ :exc:`firebase_functions.alerts.performance_fn.PerformanceThresholdAlertEvent` \\], + `None` \\] A function that takes a PerformanceThresholdAlertEvent and returns None. """ options = PerformanceOptions(**kwargs) - def on_threshold_alert_published_inner_decorator( - func: OnThresholdAlertPublishedCallable): - + def on_threshold_alert_published_inner_decorator(func: OnThresholdAlertPublishedCallable): @_functools.wraps(func) def on_threshold_alert_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import performance_event_from_ce + func(performance_event_from_ce(raw)) _util.set_func_endpoint_attr( on_threshold_alert_published_wrapped, options._endpoint( func_name=func.__name__, - alert_type='performance.threshold', + alert_type="performance.threshold", ), ) return on_threshold_alert_published_wrapped diff --git a/src/firebase_functions/alerts_fn.py b/src/firebase_functions/alerts_fn.py index ed686738..4deff617 100644 --- a/src/firebase_functions/alerts_fn.py +++ b/src/firebase_functions/alerts_fn.py @@ -19,17 +19,20 @@ import dataclasses as _dataclasses import functools as _functools import typing as _typing + import cloudevents.http as _ce -from firebase_functions.alerts import FirebaseAlertData import firebase_functions.private.util as _util +from firebase_functions.alerts import FirebaseAlertData +from firebase_functions.core import CloudEvent as _CloudEvent +from firebase_functions.core import T, _with_init -from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init -from firebase_functions.options import FirebaseAlertOptions - -# Explicitly import AlertType to make it available in the public API. 
-# pylint: disable=unused-import -from firebase_functions.options import AlertType +# Re-export AlertType from options module so users can import it directly from alerts_fn +# This provides a more convenient API: from firebase_functions.alerts_fn import AlertType +from firebase_functions.options import ( + AlertType, # noqa: F401 + FirebaseAlertOptions, +) @_dataclasses.dataclass(frozen=True) @@ -63,7 +66,7 @@ class AlertEvent(_CloudEvent[T]): @_util.copy_func_kwargs(FirebaseAlertOptions) def on_alert_published( - **kwargs + **kwargs, ) -> _typing.Callable[[OnAlertPublishedCallable], OnAlertPublishedCallable]: """ Event handler that triggers when a Firebase Alerts event is published. @@ -72,7 +75,7 @@ def on_alert_published( .. code-block:: python - from firebase_functions import alerts_fn + from firebase_functions import alerts_fn @alerts_fn.on_alert_published( alert_type=alerts_fn.AlertType.CRASHLYTICS_NEW_FATAL_ISSUE, @@ -91,10 +94,10 @@ def example(alert: alerts_fn.AlertEvent[alerts_fn.FirebaseAlertData]) -> None: options = FirebaseAlertOptions(**kwargs) def on_alert_published_inner_decorator(func: OnAlertPublishedCallable): - @_functools.wraps(func) def on_alert_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import alerts_event_from_ce + _with_init(func)(alerts_event_from_ce(raw)) _util.set_func_endpoint_attr( diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index a12e6889..fb8dfd23 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -14,6 +14,7 @@ """ Public code that is shared across modules. """ + import dataclasses as _dataclass import datetime as _datetime import typing as _typing @@ -90,9 +91,9 @@ class Change(_typing.Generic[T]): def init(callback: _typing.Callable[[], _typing.Any]) -> None: """ - Registers a function that should be run when in a production environment - before executing any functions code. - Calling this decorator more than once leads to undefined behavior. + Registers a function that should be run when in a production environment + before executing any functions code. + Calling this decorator more than once leads to undefined behavior. """ global _did_init @@ -107,9 +108,7 @@ def init(callback: _typing.Callable[[], _typing.Any]) -> None: _did_init = False -def _with_init( - fn: _typing.Callable[..., - _typing.Any]) -> _typing.Callable[..., _typing.Any]: +def _with_init(fn: _typing.Callable[..., _typing.Any]) -> _typing.Callable[..., _typing.Any]: """ A decorator that runs the init callback before running the decorated function. """ diff --git a/src/firebase_functions/db_fn.py b/src/firebase_functions/db_fn.py index 7298e994..db98f817 100644 --- a/src/firebase_functions/db_fn.py +++ b/src/firebase_functions/db_fn.py @@ -14,18 +14,20 @@ """ Module for Cloud Functions that are triggered by the Firebase Realtime Database. 
""" + # pylint: disable=protected-access import dataclasses as _dataclass +import datetime as _dt import functools as _functools import typing as _typing -import datetime as _dt -import firebase_functions.private.util as _util -import firebase_functions.private.path_pattern as _path_pattern -import firebase_functions.core as _core + import cloudevents.http as _ce -from firebase_functions.options import DatabaseOptions +import firebase_functions.core as _core +import firebase_functions.private.path_pattern as _path_pattern +import firebase_functions.private.util as _util from firebase_functions.core import Change, T +from firebase_functions.options import DatabaseOptions _event_type_written = "google.firebase.database.ref.v1.written" _event_type_created = "google.firebase.database.ref.v1.created" @@ -147,7 +149,8 @@ def example(event: Event[Change[object]]) -> None: def on_value_written_inner_decorator(func: _C1): ref_pattern = _path_pattern.PathPattern(options.reference) instance_pattern = _path_pattern.PathPattern( - options.instance if options.instance is not None else "*") + options.instance if options.instance is not None else "*" + ) @_functools.wraps(func) def on_value_written_wrapped(raw: _ce.CloudEvent): @@ -197,7 +200,8 @@ def example(event: Event[Change[object]]) -> None: def on_value_updated_inner_decorator(func: _C1): ref_pattern = _path_pattern.PathPattern(options.reference) instance_pattern = _path_pattern.PathPattern( - options.instance if options.instance is not None else "*") + options.instance if options.instance is not None else "*" + ) @_functools.wraps(func) def on_value_updated_wrapped(raw: _ce.CloudEvent): @@ -247,7 +251,8 @@ def example(event: Event[object]): def on_value_created_inner_decorator(func: _C2): ref_pattern = _path_pattern.PathPattern(options.reference) instance_pattern = _path_pattern.PathPattern( - options.instance if options.instance is not None else "*") + options.instance if options.instance is not None else "*" + ) @_functools.wraps(func) def on_value_created_wrapped(raw: _ce.CloudEvent): @@ -297,7 +302,8 @@ def example(event: Event[object]) -> None: def on_value_deleted_inner_decorator(func: _C2): ref_pattern = _path_pattern.PathPattern(options.reference) instance_pattern = _path_pattern.PathPattern( - options.instance if options.instance is not None else "*") + options.instance if options.instance is not None else "*" + ) @_functools.wraps(func) def on_value_deleted_wrapped(raw: _ce.CloudEvent): diff --git a/src/firebase_functions/eventarc_fn.py b/src/firebase_functions/eventarc_fn.py index d76772c2..0e23549e 100644 --- a/src/firebase_functions/eventarc_fn.py +++ b/src/firebase_functions/eventarc_fn.py @@ -14,9 +14,10 @@ """Cloud functions to handle Eventarc events.""" # pylint: disable=protected-access -import typing as _typing -import functools as _functools import datetime as _dt +import functools as _functools +import typing as _typing + import cloudevents.http as _ce import firebase_functions.options as _options @@ -26,9 +27,8 @@ @_util.copy_func_kwargs(_options.EventarcTriggerOptions) def on_custom_event_published( - **kwargs -) -> _typing.Callable[[_typing.Callable[[CloudEvent], None]], _typing.Callable[ - [CloudEvent], None]]: + **kwargs, +) -> _typing.Callable[[_typing.Callable[[CloudEvent], None]], _typing.Callable[[CloudEvent], None]]: """ Creates a handler for events published on the default event eventarc channel. 
@@ -52,9 +52,7 @@ def onimageresize(event: eventarc_fn.CloudEvent) -> None: """ options = _options.EventarcTriggerOptions(**kwargs) - def on_custom_event_published_decorator(func: _typing.Callable[[CloudEvent], - None]): - + def on_custom_event_published_decorator(func: _typing.Callable[[CloudEvent], None]): @_functools.wraps(func) def on_custom_event_published_wrapped(raw: _ce.CloudEvent): event_attributes = raw._get_attributes() @@ -65,8 +63,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent): id=event_dict["id"], source=event_dict["source"], specversion=event_dict["specversion"], - subject=event_dict["subject"] - if "subject" in event_dict else None, + subject=event_dict["subject"] if "subject" in event_dict else None, time=_dt.datetime.strptime( event_dict["time"], "%Y-%m-%dT%H:%M:%S.%f%z", diff --git a/src/firebase_functions/firestore_fn.py b/src/firebase_functions/firestore_fn.py index a9d4f2a6..22fd79f8 100644 --- a/src/firebase_functions/firestore_fn.py +++ b/src/firebase_functions/firestore_fn.py @@ -14,24 +14,25 @@ """ Module for Cloud Functions that are triggered by Firestore. """ + # pylint: disable=protected-access import dataclasses as _dataclass import functools as _functools import typing as _typing -import google.events.cloud.firestore as _firestore -import google.cloud.firestore_v1 as _firestore_v1 -import firebase_functions.private.util as _util -import firebase_functions.private.path_pattern as _path_pattern -import firebase_functions.core as _core -import cloudevents.http as _ce -from firebase_admin import initialize_app, get_app, _apps, _DEFAULT_APP_NAME +import cloudevents.http as _ce +import google.cloud.firestore_v1 as _firestore_v1 +import google.events.cloud.firestore as _firestore +from firebase_admin import _DEFAULT_APP_NAME, _apps, get_app, initialize_app from google.cloud._helpers import _datetime_to_pb_timestamp +from google.cloud.firestore_v1 import DocumentReference, DocumentSnapshot from google.cloud.firestore_v1 import _helpers as _firestore_helpers -from google.cloud.firestore_v1 import DocumentSnapshot, DocumentReference -from firebase_functions.options import FirestoreOptions +import firebase_functions.core as _core +import firebase_functions.private.path_pattern as _path_pattern +import firebase_functions.private.util as _util from firebase_functions.core import Change +from firebase_functions.options import FirestoreOptions _event_type_written = "google.cloud.firestore.document.v1.written" _event_type_created = "google.cloud.firestore.document.v1.created" @@ -87,8 +88,7 @@ class Event(_core.CloudEvent[_core.T]): _C1 = _typing.Callable[[_E1], None] _C2 = _typing.Callable[[_E2], None] -AuthType = _typing.Literal["service_account", "api_key", "system", - "unauthenticated", "unknown"] +AuthType = _typing.Literal["service_account", "api_key", "system", "unauthenticated", "unknown"] @_dataclass.dataclass(frozen=True) @@ -117,17 +117,18 @@ def _firestore_endpoint_handler( content_type: str = event_attributes["datacontenttype"] if "application/json" in content_type or isinstance(event_data, dict): firestore_event_data = _typing.cast( - _firestore.DocumentEventData, - _firestore.DocumentEventData.from_json(event_data)) - elif "application/protobuf" in content_type or isinstance( - event_data, bytes): + _firestore.DocumentEventData, _firestore.DocumentEventData.from_json(event_data) + ) + elif "application/protobuf" in content_type or isinstance(event_data, bytes): firestore_event_data = _typing.cast( - _firestore.DocumentEventData, - 
_firestore.DocumentEventData.deserialize(event_data)) + _firestore.DocumentEventData, _firestore.DocumentEventData.deserialize(event_data) + ) else: actual_type = type(event_data) - raise TypeError(f"Firestore: Cannot parse event payload of data type " - f"'{actual_type}' and content type '{content_type}'.") + raise TypeError( + f"Firestore: Cannot parse event payload of data type " + f"'{actual_type}' and content type '{content_type}'." + ) event_location = event_attributes["location"] event_project = event_attributes["project"] @@ -141,15 +142,15 @@ def _firestore_endpoint_handler( if _DEFAULT_APP_NAME not in _apps: initialize_app() app = get_app() - firestore_client = _firestore_v1.Client(project=app.project_id, - database=event_database) + firestore_client = _firestore_v1.Client(project=app.project_id, database=event_database) firestore_ref: DocumentReference = firestore_client.document(event_document) value_snapshot: DocumentSnapshot | None = None old_value_snapshot: DocumentSnapshot | None = None if firestore_event_data.value: document_dict = _firestore_helpers.decode_dict( - firestore_event_data.value.fields, firestore_client) + firestore_event_data.value.fields, firestore_client + ) value_snapshot = _firestore_v1.DocumentSnapshot( firestore_ref, document_dict, @@ -160,7 +161,8 @@ def _firestore_endpoint_handler( ) if firestore_event_data.old_value: document_dict = _firestore_helpers.decode_dict( - firestore_event_data.old_value.fields, firestore_client) + firestore_event_data.old_value.fields, firestore_client + ) old_value_snapshot = _firestore_v1.DocumentSnapshot( firestore_ref, document_dict, @@ -170,23 +172,23 @@ def _firestore_endpoint_handler( firestore_event_data.old_value.update_time, ) - if event_type in (_event_type_deleted, - _event_type_deleted_with_auth_context): - firestore_event_data = _typing.cast(_firestore.DocumentEventData, - old_value_snapshot) - if event_type in (_event_type_created, - _event_type_created_with_auth_context): - firestore_event_data = _typing.cast(_firestore.DocumentEventData, - value_snapshot) - if event_type in (_event_type_written, _event_type_updated, - _event_type_written_with_auth_context, - _event_type_updated_with_auth_context): + if event_type in (_event_type_deleted, _event_type_deleted_with_auth_context): + firestore_event_data = _typing.cast(_firestore.DocumentEventData, old_value_snapshot) + if event_type in (_event_type_created, _event_type_created_with_auth_context): + firestore_event_data = _typing.cast(_firestore.DocumentEventData, value_snapshot) + if event_type in ( + _event_type_written, + _event_type_updated, + _event_type_written_with_auth_context, + _event_type_updated_with_auth_context, + ): firestore_event_data = _typing.cast( _firestore.DocumentEventData, Change( before=old_value_snapshot, after=value_snapshot, - )) + ), + ) params: dict[str, str] = { **document_pattern.extract_matches(event_document), @@ -213,9 +215,9 @@ def _firestore_endpoint_handler( if event_type.endswith(".withAuthContext"): event_auth_type = event_attributes["authtype"] event_auth_id = event_attributes["authid"] - database_event_with_auth_context = AuthEvent(**vars(database_event), - auth_type=event_auth_type, - auth_id=event_auth_id) + database_event_with_auth_context = AuthEvent( + **vars(database_event), auth_type=event_auth_type, auth_id=event_auth_id + ) func(database_event_with_auth_context) else: # mypy cannot infer that the event type is correct, hence the cast @@ -245,8 +247,7 @@ def example(event: Event[Change[DocumentSnapshot]]) -> 
None: options = FirestoreOptions(**kwargs) def on_document_written_inner_decorator(func: _C1): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_written_wrapped(raw: _ce.CloudEvent): @@ -271,8 +272,7 @@ def on_document_written_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) -def on_document_written_with_auth_context(**kwargs - ) -> _typing.Callable[[_C1], _C1]: +def on_document_written_with_auth_context(**kwargs) -> _typing.Callable[[_C1], _C1]: """ Event handler that triggers when a document is created, updated, or deleted in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -296,8 +296,7 @@ def example(event: AuthEvent[Change[DocumentSnapshot]]) -> None: options = FirestoreOptions(**kwargs) def on_document_written_with_auth_context_inner_decorator(func: _C1): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_written_with_auth_context_wrapped(raw: _ce.CloudEvent): @@ -344,8 +343,7 @@ def example(event: Event[Change[DocumentSnapshot]]) -> None: options = FirestoreOptions(**kwargs) def on_document_updated_inner_decorator(func: _C1): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_updated_wrapped(raw: _ce.CloudEvent): @@ -370,8 +368,7 @@ def on_document_updated_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) -def on_document_updated_with_auth_context(**kwargs - ) -> _typing.Callable[[_C1], _C1]: +def on_document_updated_with_auth_context(**kwargs) -> _typing.Callable[[_C1], _C1]: """ Event handler that triggers when a document is updated in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -395,8 +392,7 @@ def example(event: AuthEvent[Change[DocumentSnapshot]]) -> None: options = FirestoreOptions(**kwargs) def on_document_updated_with_auth_context_inner_decorator(func: _C1): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_updated_with_auth_context_wrapped(raw: _ce.CloudEvent): @@ -443,8 +439,7 @@ def example(event: Event[DocumentSnapshot]): options = FirestoreOptions(**kwargs) def on_document_created_inner_decorator(func: _C2): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_created_wrapped(raw: _ce.CloudEvent): @@ -469,8 +464,7 @@ def on_document_created_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) -def on_document_created_with_auth_context(**kwargs - ) -> _typing.Callable[[_C2], _C2]: +def on_document_created_with_auth_context(**kwargs) -> _typing.Callable[[_C2], _C2]: """ Event handler that triggers when a document is created in Firestore. 
This trigger will also provide the authentication context of the principal who triggered @@ -494,8 +488,7 @@ def example(event: AuthEvent[DocumentSnapshot]): options = FirestoreOptions(**kwargs) def on_document_created_with_auth_context_inner_decorator(func: _C2): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_created_with_auth_context_wrapped(raw: _ce.CloudEvent): @@ -542,8 +535,7 @@ def example(event: Event[DocumentSnapshot]) -> None: options = FirestoreOptions(**kwargs) def on_document_deleted_inner_decorator(func: _C2): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_deleted_wrapped(raw: _ce.CloudEvent): @@ -568,8 +560,7 @@ def on_document_deleted_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) -def on_document_deleted_with_auth_context(**kwargs - ) -> _typing.Callable[[_C2], _C2]: +def on_document_deleted_with_auth_context(**kwargs) -> _typing.Callable[[_C2], _C2]: """ Event handler that triggers when a document is deleted in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -593,8 +584,7 @@ def example(event: AuthEvent[DocumentSnapshot]) -> None: options = FirestoreOptions(**kwargs) def on_document_deleted_with_auth_context_inner_decorator(func: _C2): - document_pattern = _path_pattern.PathPattern( - _util.normalize_path(options.document)) + document_pattern = _path_pattern.PathPattern(_util.normalize_path(options.document)) @_functools.wraps(func) def on_document_deleted_with_auth_context_wrapped(raw: _ce.CloudEvent): diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 10749e9d..7e692de0 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -14,20 +14,24 @@ """Module for functions that listen to HTTPS endpoints. These can be raw web requests and Callable RPCs. 
""" + # pylint: disable=protected-access import dataclasses as _dataclasses +import enum as _enum import functools as _functools +import json as _json import typing as _typing + import typing_extensions as _typing_extensions -import enum as _enum -import json as _json -import firebase_functions.private.util as _util -import firebase_functions.core as _core +from flask import Request, Response +from flask import jsonify as _jsonify +from flask import make_response as _make_response +from flask_cors import cross_origin as _cross_origin from functions_framework import logging as _logging -from firebase_functions.options import HttpsOptions, _GLOBAL_OPTIONS -from flask import Request, Response, make_response as _make_response, jsonify as _jsonify -from flask_cors import cross_origin as _cross_origin +import firebase_functions.core as _core +import firebase_functions.private.util as _util +from firebase_functions.options import _GLOBAL_OPTIONS, HttpsOptions class FunctionsErrorCode(str, _enum.Enum): @@ -176,40 +180,35 @@ class _HttpErrorCode: _error_code_map = { - FunctionsErrorCode.OK: - _HttpErrorCode(_CanonicalErrorCodeName.OK, 200), - FunctionsErrorCode.CANCELLED: - _HttpErrorCode(_CanonicalErrorCodeName.CANCELLED, 499), - FunctionsErrorCode.UNKNOWN: - _HttpErrorCode(_CanonicalErrorCodeName.UNKNOWN, 500), - FunctionsErrorCode.INVALID_ARGUMENT: - _HttpErrorCode(_CanonicalErrorCodeName.INVALID_ARGUMENT, 400), - FunctionsErrorCode.DEADLINE_EXCEEDED: - _HttpErrorCode(_CanonicalErrorCodeName.DEADLINE_EXCEEDED, 504), - FunctionsErrorCode.NOT_FOUND: - _HttpErrorCode(_CanonicalErrorCodeName.NOT_FOUND, 404), - FunctionsErrorCode.ALREADY_EXISTS: - _HttpErrorCode(_CanonicalErrorCodeName.ALREADY_EXISTS, 409), - FunctionsErrorCode.PERMISSION_DENIED: - _HttpErrorCode(_CanonicalErrorCodeName.PERMISSION_DENIED, 403), - FunctionsErrorCode.UNAUTHENTICATED: - _HttpErrorCode(_CanonicalErrorCodeName.UNAUTHENTICATED, 401), - FunctionsErrorCode.RESOURCE_EXHAUSTED: - _HttpErrorCode(_CanonicalErrorCodeName.RESOURCE_EXHAUSTED, 429), - FunctionsErrorCode.FAILED_PRECONDITION: - _HttpErrorCode(_CanonicalErrorCodeName.FAILED_PRECONDITION, 400), - FunctionsErrorCode.ABORTED: - _HttpErrorCode(_CanonicalErrorCodeName.ABORTED, 409), - FunctionsErrorCode.OUT_OF_RANGE: - _HttpErrorCode(_CanonicalErrorCodeName.OUT_OF_RANGE, 400), - FunctionsErrorCode.UNIMPLEMENTED: - _HttpErrorCode(_CanonicalErrorCodeName.UNIMPLEMENTED, 501), - FunctionsErrorCode.INTERNAL: - _HttpErrorCode(_CanonicalErrorCodeName.INTERNAL, 500), - FunctionsErrorCode.UNAVAILABLE: - _HttpErrorCode(_CanonicalErrorCodeName.UNAVAILABLE, 503), - FunctionsErrorCode.DATA_LOSS: - _HttpErrorCode(_CanonicalErrorCodeName.DATA_LOSS, 500), + FunctionsErrorCode.OK: _HttpErrorCode(_CanonicalErrorCodeName.OK, 200), + FunctionsErrorCode.CANCELLED: _HttpErrorCode(_CanonicalErrorCodeName.CANCELLED, 499), + FunctionsErrorCode.UNKNOWN: _HttpErrorCode(_CanonicalErrorCodeName.UNKNOWN, 500), + FunctionsErrorCode.INVALID_ARGUMENT: _HttpErrorCode( + _CanonicalErrorCodeName.INVALID_ARGUMENT, 400 + ), + FunctionsErrorCode.DEADLINE_EXCEEDED: _HttpErrorCode( + _CanonicalErrorCodeName.DEADLINE_EXCEEDED, 504 + ), + FunctionsErrorCode.NOT_FOUND: _HttpErrorCode(_CanonicalErrorCodeName.NOT_FOUND, 404), + FunctionsErrorCode.ALREADY_EXISTS: _HttpErrorCode(_CanonicalErrorCodeName.ALREADY_EXISTS, 409), + FunctionsErrorCode.PERMISSION_DENIED: _HttpErrorCode( + _CanonicalErrorCodeName.PERMISSION_DENIED, 403 + ), + FunctionsErrorCode.UNAUTHENTICATED: _HttpErrorCode( + 
_CanonicalErrorCodeName.UNAUTHENTICATED, 401 + ), + FunctionsErrorCode.RESOURCE_EXHAUSTED: _HttpErrorCode( + _CanonicalErrorCodeName.RESOURCE_EXHAUSTED, 429 + ), + FunctionsErrorCode.FAILED_PRECONDITION: _HttpErrorCode( + _CanonicalErrorCodeName.FAILED_PRECONDITION, 400 + ), + FunctionsErrorCode.ABORTED: _HttpErrorCode(_CanonicalErrorCodeName.ABORTED, 409), + FunctionsErrorCode.OUT_OF_RANGE: _HttpErrorCode(_CanonicalErrorCodeName.OUT_OF_RANGE, 400), + FunctionsErrorCode.UNIMPLEMENTED: _HttpErrorCode(_CanonicalErrorCodeName.UNIMPLEMENTED, 501), + FunctionsErrorCode.INTERNAL: _HttpErrorCode(_CanonicalErrorCodeName.INTERNAL, 500), + FunctionsErrorCode.UNAVAILABLE: _HttpErrorCode(_CanonicalErrorCodeName.UNAVAILABLE, 503), + FunctionsErrorCode.DATA_LOSS: _HttpErrorCode(_CanonicalErrorCodeName.DATA_LOSS, 500), } """ Standard error codes and HTTP statuses for different ways a request can fail, @@ -352,8 +351,7 @@ class CallableRequest(_typing.Generic[_core.T]): _C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] -def _on_call_handler(func: _C2, request: Request, - enforce_app_check: bool) -> Response: +def _on_call_handler(func: _C2, request: Request, enforce_app_check: bool) -> Response: try: if not _util.valid_on_call_request(request): _logging.error("Invalid request, unable to process.") @@ -366,27 +364,26 @@ def _on_call_handler(func: _C2, request: Request, token_status = _util.on_call_check_tokens(request) if token_status.auth == _util.OnCallTokenState.INVALID: - raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, - "Unauthenticated") + raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") if enforce_app_check and token_status.app in ( - _util.OnCallTokenState.MISSING, _util.OnCallTokenState.INVALID): - raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, - "Unauthenticated") + _util.OnCallTokenState.MISSING, + _util.OnCallTokenState.INVALID, + ): + raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") if token_status.app == _util.OnCallTokenState.VALID and token_status.app_token is not None: context = _dataclasses.replace( context, - app=AppCheckData(token_status.app_token["sub"], - token_status.app_token), + app=AppCheckData(token_status.app_token["sub"], token_status.app_token), ) if token_status.auth_token is not None: context = _dataclasses.replace( context, auth=AuthData( - token_status.auth_token["uid"] - if "uid" in token_status.auth_token else None, - token_status.auth_token), + token_status.auth_token["uid"] if "uid" in token_status.auth_token else None, + token_status.auth_token, + ), ) instance_id = request.headers.get("Firebase-Instance-ID-Token") @@ -397,8 +394,7 @@ def _on_call_handler(func: _C2, request: Request, # pushes with FCM. In that case, the FCM APIs will validate the token. 
context = _dataclasses.replace( context, - instance_id_token=request.headers.get( - "Firebase-Instance-ID-Token"), + instance_id_token=request.headers.get("Firebase-Instance-ID-Token"), ) result = _core._with_init(func)(context) return _jsonify(result=result) @@ -436,7 +432,6 @@ def example(request: Request) -> Response: options = HttpsOptions(**kwargs) def on_request_inner_decorator(func: _C1): - @_functools.wraps(func) def on_request_wrapped(request: Request) -> Response: if options.cors is not None: diff --git a/src/firebase_functions/identity_fn.py b/src/firebase_functions/identity_fn.py index dfe4e2a9..d2afead1 100644 --- a/src/firebase_functions/identity_fn.py +++ b/src/firebase_functions/identity_fn.py @@ -14,25 +14,29 @@ """Cloud functions to handle Eventarc events.""" # pylint: disable=protected-access,cyclic-import -import typing as _typing -import functools as _functools -import datetime as _dt import dataclasses as _dataclasses +import datetime as _dt +import functools as _functools +import typing as _typing from enum import Enum -import firebase_functions.options as _options -import firebase_functions.private.util as _util from flask import ( Request as _Request, +) +from flask import ( Response as _Response, ) +import firebase_functions.options as _options +import firebase_functions.private.util as _util + @_dataclasses.dataclass(frozen=True) class AuthUserInfo: """ User info that is part of the AuthUserRecord. """ + uid: str """The user identifier for the linked provider.""" @@ -57,10 +61,11 @@ class AuthUserMetadata: """ Additional metadata about the user. """ + creation_time: _dt.datetime """The date the user was created.""" - last_sign_in_time: _typing.Optional[_dt.datetime] + last_sign_in_time: _dt.datetime | None """The date the user last signed in.""" @@ -348,14 +353,12 @@ class BeforeSignInResponse(BeforeCreateResponse, total=False): """The user's session claims object if available.""" -BeforeUserCreatedCallable = _typing.Callable[[AuthBlockingEvent], - BeforeCreateResponse | None] +BeforeUserCreatedCallable = _typing.Callable[[AuthBlockingEvent], BeforeCreateResponse | None] """ The type of the callable for 'before_user_created' blocking events. """ -BeforeUserSignedInCallable = _typing.Callable[[AuthBlockingEvent], - BeforeSignInResponse | None] +BeforeUserSignedInCallable = _typing.Callable[[AuthBlockingEvent], BeforeSignInResponse | None] """ The type of the callable for 'before_user_signed_in' blocking events. """ @@ -393,6 +396,7 @@ def before_user_signed_in_decorator(func: BeforeUserSignedInCallable): @_functools.wraps(func) def before_user_signed_in_wrapped(request: _Request) -> _Response: from firebase_functions.private._identity_fn import before_operation_handler + return before_operation_handler( func, event_type_before_sign_in, @@ -447,6 +451,7 @@ def before_user_created_decorator(func: BeforeUserCreatedCallable): @_functools.wraps(func) def before_user_created_wrapped(request: _Request) -> _Response: from firebase_functions.private._identity_fn import before_operation_handler + return before_operation_handler( func, event_type_before_create, diff --git a/src/firebase_functions/logger.py b/src/firebase_functions/logger.py index 62562bb2..0e22fc5e 100644 --- a/src/firebase_functions/logger.py +++ b/src/firebase_functions/logger.py @@ -6,6 +6,7 @@ import json as _json import sys as _sys import typing as _typing + import typing_extensions as _typing_extensions # If encoding is not 'utf-8', change it to 'utf-8'. 
@@ -49,25 +50,28 @@ def _entry_from_args(severity: LogSeverity, *args, **kwargs) -> LogEntry: Creates a `LogEntry` from the given arguments. """ - message: str = " ".join([ - value if isinstance(value, str) else _json.dumps( - _remove_circular(value), ensure_ascii=False) for value in args - ]) + message: str = " ".join( + [ + value + if isinstance(value, str) + else _json.dumps(_remove_circular(value), ensure_ascii=False) + for value in args + ] + ) - other: _typing.Dict[str, _typing.Any] = { + other: dict[str, _typing.Any] = { key: value if isinstance(value, str) else _remove_circular(value) for key, value in kwargs.items() } - entry: _typing.Dict[str, _typing.Any] = {"severity": severity, **other} + entry: dict[str, _typing.Any] = {"severity": severity, **other} if message: entry["message"] = message return _typing.cast(LogEntry, entry) -def _remove_circular(obj: _typing.Any, - refs: _typing.Set[_typing.Any] | None = None): +def _remove_circular(obj: _typing.Any, refs: set[_typing.Any] | None = None): """ Removes circular references from the given object and replaces them with "[CIRCULAR]". """ @@ -80,15 +84,13 @@ def _remove_circular(obj: _typing.Any, return "[CIRCULAR]" # For non-primitive objects, add the current object's id to the recursion stack - if not isinstance(obj, (str, int, float, bool, type(None))): + if not isinstance(obj, str | int | float | bool | type(None)): refs.add(id(obj)) # Recursively process the object based on its type result: _typing.Any if isinstance(obj, dict): - result = { - key: _remove_circular(value, refs) for key, value in obj.items() - } + result = {key: _remove_circular(value, refs) for key, value in obj.items()} elif isinstance(obj, list): result = [_remove_circular(item, refs) for item in obj] elif isinstance(obj, tuple): @@ -97,7 +99,7 @@ def _remove_circular(obj: _typing.Any, result = obj # Remove the object's id from the recursion stack after processing - if not isinstance(obj, (str, int, float, bool, type(None))): + if not isinstance(obj, str | int | float | bool | type(None)): refs.remove(id(obj)) return result @@ -111,8 +113,7 @@ def _get_write_file(severity: LogSeverity) -> _typing.TextIO: def write(entry: LogEntry) -> None: write_file = _get_write_file(entry["severity"]) - print(_json.dumps(_remove_circular(entry), ensure_ascii=False), - file=write_file) + print(_json.dumps(_remove_circular(entry), ensure_ascii=False), file=write_file) def debug(*args, **kwargs) -> None: diff --git a/src/firebase_functions/options.py b/src/firebase_functions/options.py index 2f7db7da..badf87e5 100644 --- a/src/firebase_functions/options.py +++ b/src/firebase_functions/options.py @@ -15,23 +15,25 @@ Module for options that can be used to configure Cloud Functions deployments. 
""" + # pylint: disable=protected-access -import enum as _enum import dataclasses as _dataclasses +import enum as _enum import re as _re import typing as _typing from zoneinfo import ZoneInfo as _ZoneInfo import firebase_functions.private.manifest as _manifest -import firebase_functions.private.util as _util import firebase_functions.private.path_pattern as _path_pattern -from firebase_functions.params import SecretParam, Expression +import firebase_functions.private.util as _util +from firebase_functions.params import Expression, SecretParam Timezone = _ZoneInfo """An alias of the zoneinfo.ZoneInfo for convenience.""" RESET_VALUE = _util.Sentinel( - "Special configuration value to reset configuration to platform default.") + "Special configuration value to reset configuration to platform default." +) """Special configuration value to reset configuration to platform default.""" @@ -134,19 +136,18 @@ def __str__(self) -> str: @_dataclasses.dataclass(frozen=True) -class RateLimits(): +class RateLimits: """ How congestion control should be applied to the function. """ - max_concurrent_dispatches: int | Expression[ - int] | _util.Sentinel | None = None + + max_concurrent_dispatches: int | Expression[int] | _util.Sentinel | None = None """ The maximum number of requests that can be outstanding at a time. If left unspecified, defaults to 1000. """ - max_dispatches_per_second: int | Expression[ - int] | _util.Sentinel | None = None + max_dispatches_per_second: int | Expression[int] | _util.Sentinel | None = None """ The maximum number of requests that can be invoked per second. If left unspecified, defaults to 500. @@ -154,7 +155,7 @@ class RateLimits(): @_dataclasses.dataclass(frozen=True) -class RetryConfig(): +class RetryConfig: """ How a task should be retried in the event of a non-2xx return. """ @@ -352,8 +353,7 @@ def convert_secret(secret) -> str: secret_value = secret.name return secret_value - merged_options["secrets"] = list( - map(convert_secret, _typing.cast(list, self.secrets))) + merged_options["secrets"] = list(map(convert_secret, _typing.cast(list, self.secrets))) # _util.Sentinel values are converted to `None` in ManifestEndpoint generation # after other None values are removed - so as to keep them in the generated # YAML output as 'null' values. 
@@ -363,17 +363,14 @@ def _endpoint(self, **kwargs) -> _manifest.ManifestEndpoint: assert kwargs["func_name"] is not None options_dict = self._asdict_with_global_options() options = self.__class__(**options_dict) - secret_envs: list[ - _manifest.SecretEnvironmentVariable] | _util.Sentinel = [] + secret_envs: list[_manifest.SecretEnvironmentVariable] | _util.Sentinel = [] if options.secrets is not None: if isinstance(options.secrets, list): - def convert_secret( - secret) -> _manifest.SecretEnvironmentVariable: + def convert_secret(secret) -> _manifest.SecretEnvironmentVariable: return {"key": secret} - secret_envs = list( - map(convert_secret, _typing.cast(list, options.secrets))) + secret_envs = list(map(convert_secret, _typing.cast(list, options.secrets))) elif options.secrets is _util.Sentinel: secret_envs = _typing.cast(_util.Sentinel, options.secrets) @@ -385,16 +382,16 @@ def convert_secret( vpc: _manifest.VpcSettings | None = None if isinstance(options.vpc_connector, str): - vpc = ({ - "connector": - options.vpc_connector, - "egressSettings": - options.vpc_connector_egress_settings.value if isinstance( - options.vpc_connector_egress_settings, VpcEgressSetting) - else options.vpc_connector_egress_settings - } if options.vpc_connector_egress_settings is not None else { - "connector": options.vpc_connector - }) + vpc = ( + { + "connector": options.vpc_connector, + "egressSettings": options.vpc_connector_egress_settings.value + if isinstance(options.vpc_connector_egress_settings, VpcEgressSetting) + else options.vpc_connector_egress_settings, + } + if options.vpc_connector_egress_settings is not None + else {"connector": options.vpc_connector} + ) endpoint = _manifest.ManifestEndpoint( entryPoint=kwargs["func_name"], @@ -445,29 +442,35 @@ def _endpoint( self, **kwargs, ) -> _manifest.ManifestEndpoint: - rate_limits: _manifest.RateLimits | None = _manifest.RateLimits( - maxConcurrentDispatches=self.rate_limits.max_concurrent_dispatches, - maxDispatchesPerSecond=self.rate_limits.max_dispatches_per_second, - ) if self.rate_limits is not None else None - - retry_config: _manifest.RetryConfigTasks | None = _manifest.RetryConfigTasks( - maxAttempts=self.retry_config.max_attempts, - maxRetrySeconds=self.retry_config.max_retry_seconds, - maxBackoffSeconds=self.retry_config.max_backoff_seconds, - maxDoublings=self.retry_config.max_doublings, - minBackoffSeconds=self.retry_config.min_backoff_seconds, - ) if self.retry_config is not None else None + rate_limits: _manifest.RateLimits | None = ( + _manifest.RateLimits( + maxConcurrentDispatches=self.rate_limits.max_concurrent_dispatches, + maxDispatchesPerSecond=self.rate_limits.max_dispatches_per_second, + ) + if self.rate_limits is not None + else None + ) + + retry_config: _manifest.RetryConfigTasks | None = ( + _manifest.RetryConfigTasks( + maxAttempts=self.retry_config.max_attempts, + maxRetrySeconds=self.retry_config.max_retry_seconds, + maxBackoffSeconds=self.retry_config.max_backoff_seconds, + maxDoublings=self.retry_config.max_doublings, + minBackoffSeconds=self.retry_config.min_backoff_seconds, + ) + if self.retry_config is not None + else None + ) kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "taskQueueTrigger": - _manifest.TaskQueueTrigger( - rateLimits=rate_limits, - retryConfig=retry_config, - ), + "taskQueueTrigger": _manifest.TaskQueueTrigger( + rateLimits=rate_limits, + retryConfig=retry_config, + ), } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return 
_manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) def _required_apis(self) -> list[_manifest.ManifestRequiredApi]: return [ @@ -506,11 +509,9 @@ def _endpoint( kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "eventTrigger": - event_trigger, + "eventTrigger": event_trigger, } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -533,10 +534,14 @@ def _endpoint( "topic": self.topic, } event_type = "google.cloud.pubsub.topic.v1.messagePublished" - return _manifest.ManifestEndpoint(**_typing.cast( - _typing.Dict, - _dataclasses.asdict(super()._endpoint( - **kwargs, event_filters=event_filters, event_type=event_type)))) + return _manifest.ManifestEndpoint( + **_typing.cast( + dict, + _dataclasses.asdict( + super()._endpoint(**kwargs, event_filters=event_filters, event_type=event_type) + ), + ) + ) class AlertType(str, _enum.Enum): @@ -633,13 +638,18 @@ def _endpoint( event_filters["appid"] = self.app_id event_type = "google.firebase.firebasealerts.alerts.v1.published" - return _manifest.ManifestEndpoint(**_typing.cast( - _typing.Dict, - _dataclasses.asdict(super()._endpoint( - **kwargs, - event_filters=event_filters, - event_type=event_type, - )))) + return _manifest.ManifestEndpoint( + **_typing.cast( + dict, + _dataclasses.asdict( + super()._endpoint( + **kwargs, + event_filters=event_filters, + event_type=event_type, + ) + ), + ) + ) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -724,7 +734,8 @@ def _endpoint( ) -> _manifest.ManifestEndpoint: assert kwargs["alert_type"] is not None return FirebaseAlertOptions( - alert_type=kwargs["alert_type"],)._endpoint(**kwargs) + alert_type=kwargs["alert_type"], + )._endpoint(**kwargs) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -765,16 +776,22 @@ def _endpoint( **kwargs, ) -> _manifest.ManifestEndpoint: event_filters = {} if self.filters is None else self.filters - endpoint = _manifest.ManifestEndpoint(**_typing.cast( - _typing.Dict, - _dataclasses.asdict(super()._endpoint( - **kwargs, - event_filters=event_filters, - event_type=self.event_type, - )))) + endpoint = _manifest.ManifestEndpoint( + **_typing.cast( + dict, + _dataclasses.asdict( + super()._endpoint( + **kwargs, + event_filters=event_filters, + event_type=self.event_type, + ) + ), + ) + ) assert endpoint.eventTrigger is not None - channel = (self.channel if self.channel is not None else - "locations/us-central1/channels/firebase") + channel = ( + self.channel if self.channel is not None else "locations/us-central1/channels/firebase" + ) endpoint.eventTrigger["channel"] = channel return endpoint @@ -848,15 +865,13 @@ def _endpoint( kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "scheduleTrigger": - _manifest.ScheduleTrigger( - schedule=self.schedule, - timeZone=time_zone, - retryConfig=retry_config, - ), + "scheduleTrigger": _manifest.ScheduleTrigger( + schedule=self.schedule, + timeZone=time_zone, + retryConfig=retry_config, + ), } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) def _required_apis(self) -> list[_manifest.ManifestRequiredApi]: return [ @@ -893,7 +908,8 @@ def _endpoint( raise ValueError( "Missing bucket name. 
If you are unit testing, please specify a bucket name" " by providing a bucket name directly to the event handler or by setting the" - " FIREBASE_CONFIG environment variable.") + " FIREBASE_CONFIG environment variable." + ) event_filters: _typing.Any = { "bucket": bucket, } @@ -905,11 +921,9 @@ def _endpoint( kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "eventTrigger": - event_trigger, + "eventTrigger": event_trigger, } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -961,11 +975,9 @@ def _endpoint( kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "eventTrigger": - event_trigger, + "eventTrigger": event_trigger, } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -1000,20 +1012,16 @@ def _endpoint( eventType=kwargs["event_type"], options=_manifest.BlockingTriggerOptions( idToken=self.id_token if self.id_token is not None else False, - accessToken=self.access_token - if self.access_token is not None else False, - refreshToken=self.refresh_token - if self.refresh_token is not None else False, + accessToken=self.access_token if self.access_token is not None else False, + refreshToken=self.refresh_token if self.refresh_token is not None else False, ), ) kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "blockingTrigger": - blocking_trigger, + "blockingTrigger": blocking_trigger, } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) def _required_apis(self) -> list[_manifest.ManifestRequiredApi]: return [ @@ -1057,10 +1065,8 @@ def _endpoint( document_pattern: _path_pattern.PathPattern = kwargs["document_pattern"] event_filter_document = document_pattern.value event_filters: _typing.Any = { - "database": - self.database if self.database is not None else "(default)", - "namespace": - self.namespace if self.namespace is not None else "(default)", + "database": self.database if self.database is not None else "(default)", + "namespace": self.namespace if self.namespace is not None else "(default)", } event_filters_path_patterns: _typing.Any = {} if document_pattern.has_wildcards: @@ -1076,11 +1082,9 @@ def _endpoint( kwargs_merged = { **_dataclasses.asdict(super()._endpoint(**kwargs)), - "eventTrigger": - event_trigger, + "eventTrigger": event_trigger, } - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) @_dataclasses.dataclass(frozen=True, kw_only=True) @@ -1090,8 +1094,7 @@ class HttpsOptions(RuntimeOptions): Internal use only. """ - invoker: str | list[str] | _typing.Literal["public", - "private"] | None = None + invoker: str | list[str] | _typing.Literal["public", "private"] | None = None """ Invoker to set access control on HTTP functions. """ @@ -1132,9 +1135,9 @@ def _endpoint( invoker = self.invoker if isinstance(invoker, str): invoker = [invoker] - assert len( - invoker - ) >= 1, "HttpsOptions: Invalid option for invoker - must be a non-empty list." + assert len(invoker) >= 1, ( + "HttpsOptions: Invalid option for invoker - must be a non-empty list." 
+ ) assert "" not in invoker, ( "HttpsOptions: Invalid option for invoker - must be a non-empty string." ) @@ -1146,8 +1149,7 @@ def _endpoint( https_trigger["invoker"] = invoker kwargs_merged["httpsTrigger"] = https_trigger - return _manifest.ManifestEndpoint( - **_typing.cast(_typing.Dict, kwargs_merged)) + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) _GLOBAL_OPTIONS = RuntimeOptions() diff --git a/src/firebase_functions/params.py b/src/firebase_functions/params.py index 4aa74062..32853f08 100644 --- a/src/firebase_functions/params.py +++ b/src/firebase_functions/params.py @@ -14,11 +14,11 @@ """Module for params that can make Cloud Functions codebases generic.""" import abc as _abc -import json as _json import dataclasses as _dataclasses +import enum as _enum +import json as _json import os as _os import re as _re -import enum as _enum import typing as _typing _T = _typing.TypeVar("_T", str, int, float, bool, list) @@ -47,13 +47,11 @@ def value(self) -> _T: def _obj_cel_name(obj: _T) -> _T: - return obj if not isinstance(obj, Expression) else object.__getattribute__( - obj, "_cel_") + return obj if not isinstance(obj, Expression) else object.__getattribute__(obj, "_cel_") def _quote_if_string(literal: _T) -> _T: - return _obj_cel_name(literal) if not isinstance(literal, - str) else f'"{literal}"' + return _obj_cel_name(literal) if not isinstance(literal, str) else f'"{literal}"' _params: dict[str, Expression] = {} @@ -65,6 +63,7 @@ class TernaryExpression(Expression[_T], _typing.Generic[_T]): A CEL expression that evaluates to one of two values based on the value of another expression. """ + test: Expression[bool] if_true: _T if_false: _T @@ -87,6 +86,7 @@ class CompareExpression(Expression[bool], _typing.Generic[_T]): A CEL expression that evaluates to boolean true or false based on a comparison between the value of another expression and a literal of that same type. """ + comparator: str left: Expression[_T] right: _T @@ -141,7 +141,7 @@ class SelectInput(_typing.Generic[_T]): @_dataclasses.dataclass(frozen=True) -class MultiSelectInput(): +class MultiSelectInput: """ Specifies that a Param's value should be determined by having the user select a subset from a list of pre-canned options interactively at deploy-time. @@ -179,6 +179,7 @@ class TextInput: class ResourceType(str, _enum.Enum): """The type of resource that a picker should pick.""" + STORAGE_BUCKET = "storage.googleapis.com/Bucket" def __str__(self) -> str: @@ -231,8 +232,7 @@ class Param(Expression[_T]): deployments. """ - input: TextInput | ResourceInput | SelectInput[ - _T] | MultiSelectInput | None = None + input: TextInput | ResourceInput | SelectInput[_T] | MultiSelectInput | None = None """ The type of input that is required for this param, e.g. TextInput. """ @@ -254,7 +254,8 @@ def __post_init__(self): if not _re.match(r"^[A-Z0-9_]+$", self.name): raise ValueError( "Parameter names must only use uppercase letters, numbers and " - "underscores, e.g. 'UPPER_SNAKE_CASE'.") + "underscores, e.g. 'UPPER_SNAKE_CASE'." + ) if self.name in _params: raise ValueError( f"Duplicate Parameter Error: The parameter '{self.name}' has already been declared." @@ -294,7 +295,8 @@ def __post_init__(self): if not _re.match(r"^[A-Z0-9_]+$", self.name): raise ValueError( "Parameter names must only use uppercase letters, numbers and " - "underscores, e.g. 'UPPER_SNAKE_CASE'.") + "underscores, e.g. 'UPPER_SNAKE_CASE'." 
+ ) if self.name in _params: raise ValueError( f"Duplicate Parameter Error: The parameter '{self.name}' has already been declared." @@ -323,10 +325,9 @@ def value(self) -> str: return _os.environ[self.name] if self.default is not None: - return self.default.value if isinstance( - self.default, Expression) else self.default + return self.default.value if isinstance(self.default, Expression) else self.default - return str() + return "" @_dataclasses.dataclass(frozen=True) @@ -338,9 +339,8 @@ def value(self) -> int: if _os.environ.get(self.name) is not None: return int(_os.environ[self.name]) if self.default is not None: - return self.default.value if isinstance( - self.default, Expression) else self.default - return int() + return self.default.value if isinstance(self.default, Expression) else self.default + return 0 @_dataclasses.dataclass(frozen=True) @@ -356,9 +356,8 @@ def value(self) -> float: if _os.environ.get(self.name) is not None: return float(_os.environ[self.name]) if self.default is not None: - return self.default.value if isinstance( - self.default, Expression) else self.default - return float() + return self.default.value if isinstance(self.default, Expression) else self.default + return 0.0 @_dataclasses.dataclass(frozen=True) @@ -371,8 +370,7 @@ def value(self) -> bool: if env_value is not None: return env_value.lower() == "true" if self.default is not None: - return self.default.value if isinstance( - self.default, Expression) else self.default + return self.default.value if isinstance(self.default, Expression) else self.default return False @@ -386,8 +384,7 @@ def value(self) -> list[str]: # If the environment variable starts with "[" and ends with "]", # then assume it is a JSON array and try to parse it. # (This is for Cloud Run (v2 Functions), the environment variable is a JSON array.) - if _os.environ[self.name].startswith("[") and _os.environ[ - self.name].endswith("]"): + if _os.environ[self.name].startswith("[") and _os.environ[self.name].endswith("]"): try: return _json.loads(_os.environ[self.name]) except _json.JSONDecodeError: @@ -397,8 +394,7 @@ def value(self) -> list[str]: # variable is a comma-separated list.) 
return list(filter(len, _os.environ[self.name].split(","))) if self.default is not None: - return self.default.value if isinstance( - self.default, Expression) else self.default + return self.default.value if isinstance(self.default, Expression) else self.default return [] diff --git a/src/firebase_functions/private/_alerts_fn.py b/src/firebase_functions/private/_alerts_fn.py index bd4484a6..b7796ee2 100644 --- a/src/firebase_functions/private/_alerts_fn.py +++ b/src/firebase_functions/private/_alerts_fn.py @@ -15,15 +15,17 @@ # pylint: disable=protected-access,cyclic-import import typing as _typing + import cloudevents.http as _ce +from functions_framework import logging as _logging + import firebase_functions.private.util as _util from firebase_functions.alerts import FirebaseAlertData -from functions_framework import logging as _logging - def plan_update_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.billing_fn import PlanUpdatePayload + return PlanUpdatePayload( notification_type=payload["notificationType"], billing_plan=payload["billingPlan"], @@ -33,6 +35,7 @@ def plan_update_payload_from_ce_payload(payload: dict): def plan_automated_update_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.billing_fn import PlanAutomatedUpdatePayload + return PlanAutomatedUpdatePayload( notification_type=payload["notificationType"], billing_plan=payload["billingPlan"], @@ -41,6 +44,7 @@ def plan_automated_update_payload_from_ce_payload(payload: dict): def in_app_feedback_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.app_distribution_fn import InAppFeedbackPayload + return InAppFeedbackPayload( feedback_report=payload["feedbackReport"], feedback_console_uri=payload["feedbackConsoleUri"], @@ -54,6 +58,7 @@ def in_app_feedback_payload_from_ce_payload(payload: dict): def new_tester_device_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.app_distribution_fn import NewTesterDevicePayload + return NewTesterDevicePayload( tester_name=payload["testerName"], tester_email=payload["testerEmail"], @@ -64,6 +69,7 @@ def new_tester_device_payload_from_ce_payload(payload: dict): def threshold_alert_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.performance_fn import ThresholdAlertPayload + return ThresholdAlertPayload( event_name=payload["eventName"], event_type=payload["eventType"], @@ -81,6 +87,7 @@ def threshold_alert_payload_from_ce_payload(payload: dict): def issue_from_ce_payload(payload: dict): from firebase_functions.alerts.crashlytics_fn import Issue + return Issue( id=payload["id"], title=payload["title"], @@ -91,21 +98,24 @@ def issue_from_ce_payload(payload: dict): def new_fatal_issue_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.crashlytics_fn import NewFatalIssuePayload + return NewFatalIssuePayload(issue=issue_from_ce_payload(payload["issue"])) def new_nonfatal_issue_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.crashlytics_fn import NewNonfatalIssuePayload - return NewNonfatalIssuePayload( - issue=issue_from_ce_payload(payload["issue"])) + + return NewNonfatalIssuePayload(issue=issue_from_ce_payload(payload["issue"])) def regression_alert_payload_from_ce_payload(payload: dict): from firebase_functions.alerts.crashlytics_fn import RegressionAlertPayload - return RegressionAlertPayload(type=payload["type"], - issue=issue_from_ce_payload(payload["issue"]), - resolve_time=_util.timestamp_conversion( - payload["resolveTime"])) + + return 
RegressionAlertPayload( + type=payload["type"], + issue=issue_from_ce_payload(payload["issue"]), + resolve_time=_util.timestamp_conversion(payload["resolveTime"]), + ) def trending_issue_details_from_ce_payload(payload: dict): @@ -125,8 +135,7 @@ def stability_digest_payload_from_ce_payload(payload: dict): return StabilityDigestPayload( digest_date=_util.timestamp_conversion(payload["digestDate"]), trending_issues=[ - trending_issue_details_from_ce_payload(issue) - for issue in payload["trendingIssues"] + trending_issue_details_from_ce_payload(issue) for issue in payload["trendingIssues"] ], ) @@ -149,7 +158,9 @@ def new_anr_issue_payload_from_ce_payload(payload: dict): return NewAnrIssuePayload(issue=issue_from_ce_payload(payload["issue"])) -def firebase_alert_data_from_ce(event_dict: dict,) -> FirebaseAlertData: +def firebase_alert_data_from_ce( + event_dict: dict, +) -> FirebaseAlertData: from firebase_functions.options import AlertType alert_type: str = event_dict["alerttype"] @@ -157,8 +168,7 @@ def firebase_alert_data_from_ce(event_dict: dict,) -> FirebaseAlertData: if alert_type == AlertType.CRASHLYTICS_NEW_FATAL_ISSUE.value: alert_payload = new_fatal_issue_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.CRASHLYTICS_NEW_NONFATAL_ISSUE.value: - alert_payload = new_nonfatal_issue_payload_from_ce_payload( - alert_payload) + alert_payload = new_nonfatal_issue_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.CRASHLYTICS_REGRESSION.value: alert_payload = regression_alert_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.CRASHLYTICS_STABILITY_DIGEST.value: @@ -170,8 +180,7 @@ def firebase_alert_data_from_ce(event_dict: dict,) -> FirebaseAlertData: elif alert_type == AlertType.BILLING_PLAN_UPDATE.value: alert_payload = plan_update_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.BILLING_PLAN_AUTOMATED_UPDATE.value: - alert_payload = plan_automated_update_payload_from_ce_payload( - alert_payload) + alert_payload = plan_automated_update_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.APP_DISTRIBUTION_NEW_TESTER_IOS_DEVICE.value: alert_payload = new_tester_device_payload_from_ce_payload(alert_payload) elif alert_type == AlertType.APP_DISTRIBUTION_IN_APP_FEEDBACK.value: @@ -184,7 +193,8 @@ def firebase_alert_data_from_ce(event_dict: dict,) -> FirebaseAlertData: return FirebaseAlertData( create_time=_util.timestamp_conversion(event_dict["createTime"]), end_time=_util.timestamp_conversion(event_dict["endTime"]) - if "endTime" in event_dict else None, + if "endTime" in event_dict + else None, payload=alert_payload, ) diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py index f13d150a..0e09c04a 100644 --- a/src/firebase_functions/private/_identity_fn.py +++ b/src/firebase_functions/private/_identity_fn.py @@ -12,25 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Cloud functions to handle Eventarc events.""" + # pylint: disable=protected-access -import typing as _typing import datetime as _dt -import time as _time import json as _json +import time as _time +import typing as _typing -from firebase_functions.core import _with_init -from firebase_functions.https_fn import HttpsError, FunctionsErrorCode - -import firebase_functions.private.util as _util -import firebase_functions.private.token_verifier as _token_verifier from flask import ( Request as _Request, +) +from flask import ( Response as _Response, - make_response as _make_response, +) +from flask import ( jsonify as _jsonify, ) +from flask import ( + make_response as _make_response, +) from functions_framework import logging as _logging +import firebase_functions.private.token_verifier as _token_verifier +import firebase_functions.private.util as _util +from firebase_functions.core import _with_init +from firebase_functions.https_fn import FunctionsErrorCode, HttpsError + _claims_max_payload_size = 1000 _disallowed_custom_claims = [ "acr", @@ -53,6 +60,7 @@ def _auth_user_info_from_token_data(token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AuthUserInfo + return AuthUserInfo( uid=token_data["uid"], provider_id=token_data["provider_id"], @@ -65,24 +73,24 @@ def _auth_user_info_from_token_data(token_data: dict[str, _typing.Any]): def _auth_user_metadata_from_token_data(token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AuthUserMetadata - creation_time = _dt.datetime.utcfromtimestamp( - int(token_data["creation_time"]) / 1000.0) + + creation_time = _dt.datetime.utcfromtimestamp(int(token_data["creation_time"]) / 1000.0) last_sign_in_time = None if "last_sign_in_time" in token_data: last_sign_in_time = _dt.datetime.utcfromtimestamp( - int(token_data["last_sign_in_time"]) / 1000.0) + int(token_data["last_sign_in_time"]) / 1000.0 + ) - return AuthUserMetadata(creation_time=creation_time, - last_sign_in_time=last_sign_in_time) + return AuthUserMetadata(creation_time=creation_time, last_sign_in_time=last_sign_in_time) def _auth_multi_factor_info_from_token_data(token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AuthMultiFactorInfo + enrollment_time = token_data.get("enrollment_time") if enrollment_time: enrollment_time = _dt.datetime.fromisoformat(enrollment_time) - factor_id = token_data["factor_id"] if not token_data.get( - "phone_number") else "phone" + factor_id = token_data["factor_id"] if not token_data.get("phone_number") else "phone" return AuthMultiFactorInfo( uid=token_data["uid"], factor_id=factor_id, @@ -92,8 +100,7 @@ def _auth_multi_factor_info_from_token_data(token_data: dict[str, _typing.Any]): ) -def _auth_multi_factor_settings_from_token_data(token_data: dict[str, - _typing.Any]): +def _auth_multi_factor_settings_from_token_data(token_data: dict[str, _typing.Any]): if not token_data: return None @@ -112,6 +119,7 @@ def _auth_multi_factor_settings_from_token_data(token_data: dict[str, def _auth_user_record_from_token_data(token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AuthUserRecord + return AuthUserRecord( uid=token_data["uid"], email=token_data.get("email"), @@ -122,24 +130,24 @@ def _auth_user_record_from_token_data(token_data: dict[str, _typing.Any]): disabled=token_data.get("disabled", False), metadata=_auth_user_metadata_from_token_data(token_data["metadata"]), provider_data=[ - _auth_user_info_from_token_data(info) - for info in 
token_data["provider_data"] + _auth_user_info_from_token_data(info) for info in token_data["provider_data"] ], password_hash=token_data.get("password_hash"), password_salt=token_data.get("password_salt"), custom_claims=token_data.get("custom_claims"), tenant_id=token_data.get("tenant_id"), - tokens_valid_after_time=_dt.datetime.utcfromtimestamp( - token_data["tokens_valid_after_time"]) - if token_data.get("tokens_valid_after_time") else None, - multi_factor=_auth_multi_factor_settings_from_token_data( - token_data["multi_factor"]) - if "multi_factor" in token_data else None, + tokens_valid_after_time=_dt.datetime.utcfromtimestamp(token_data["tokens_valid_after_time"]) + if token_data.get("tokens_valid_after_time") + else None, + multi_factor=_auth_multi_factor_settings_from_token_data(token_data["multi_factor"]) + if "multi_factor" in token_data + else None, ) def _additional_user_info_from_token_data(token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AdditionalUserInfo + raw_user_info = token_data.get("raw_user_info") profile = None username = None @@ -155,9 +163,11 @@ def _additional_user_info_from_token_data(token_data: dict[str, _typing.Any]): elif sign_in_method == "twitter.com": username = profile.get("screen_name") - provider_id: str = ("password" - if token_data.get("sign_in_method") == "emailLink" else - str(token_data.get("sign_in_method"))) + provider_id: str = ( + "password" + if token_data.get("sign_in_method") == "emailLink" + else str(token_data.get("sign_in_method")) + ) is_new_user = token_data.get("event_type") == "beforeCreate" @@ -170,23 +180,27 @@ def _additional_user_info_from_token_data(token_data: dict[str, _typing.Any]): ) -def _credential_from_token_data(token_data: dict[str, _typing.Any], - time: float): - if (not token_data.get("sign_in_attributes") and - not token_data.get("oauth_id_token") and - not token_data.get("oauth_access_token") and - not token_data.get("oauth_refresh_token")): +def _credential_from_token_data(token_data: dict[str, _typing.Any], time: float): + if ( + not token_data.get("sign_in_attributes") + and not token_data.get("oauth_id_token") + and not token_data.get("oauth_access_token") + and not token_data.get("oauth_refresh_token") + ): return None from firebase_functions.identity_fn import Credential oauth_expires_in = token_data.get("oauth_expires_in") - expiration_time = (_dt.datetime.utcfromtimestamp(time + oauth_expires_in) - if oauth_expires_in else None) + expiration_time = ( + _dt.datetime.utcfromtimestamp(time + oauth_expires_in) if oauth_expires_in else None + ) - provider_id: str = ("password" - if token_data.get("sign_in_method") == "emailLink" else - str(token_data.get("sign_in_method"))) + provider_id: str = ( + "password" + if token_data.get("sign_in_method") == "emailLink" + else str(token_data.get("sign_in_method")) + ) return Credential( claims=token_data.get("sign_in_attributes"), @@ -200,9 +214,9 @@ def _credential_from_token_data(token_data: dict[str, _typing.Any], ) -def _auth_blocking_event_from_token_data(event_type: str, - token_data: dict[str, _typing.Any]): +def _auth_blocking_event_from_token_data(event_type: str, token_data: dict[str, _typing.Any]): from firebase_functions.identity_fn import AuthBlockingEvent + return AuthBlockingEvent( data=_auth_user_record_from_token_data(token_data["user_record"]), locale=token_data.get("locale"), @@ -229,10 +243,8 @@ def _validate_auth_response( if auth_response is None: auth_response = {} - custom_claims: dict[str, - _typing.Any] | None = 
auth_response.get("custom_claims") - session_claims: dict[str, _typing.Any] | None = auth_response.get( - "session_claims") + custom_claims: dict[str, _typing.Any] | None = auth_response.get("custom_claims") + session_claims: dict[str, _typing.Any] | None = auth_response.get("session_claims") if session_claims and event_type == event_type_before_create: raise HttpsError( @@ -242,10 +254,7 @@ def _validate_auth_response( ) if custom_claims: - invalid_claims = [ - claim for claim in _disallowed_custom_claims - if claim in custom_claims - ] + invalid_claims = [claim for claim in _disallowed_custom_claims if claim in custom_claims] if invalid_claims: raise HttpsError( @@ -262,11 +271,7 @@ def _validate_auth_response( ) if event_type == event_type_before_sign_in and session_claims: - - invalid_claims = [ - claim for claim in _disallowed_custom_claims - if claim in session_claims - ] + invalid_claims = [claim for claim in _disallowed_custom_claims if claim in session_claims] if invalid_claims: raise HttpsError( @@ -282,10 +287,7 @@ def _validate_auth_response( f"{_claims_max_payload_size} characters.", ) - combined_claims = { - **(custom_claims if custom_claims else {}), - **session_claims - } + combined_claims = {**(custom_claims if custom_claims else {}), **session_claims} if len(_json.dumps(combined_claims)) > _claims_max_payload_size: raise HttpsError( @@ -309,27 +311,23 @@ def _validate_auth_response( if "session_claims" in auth_response_keys: auth_response_dict["sessionClaims"] = auth_response["session_claims"] if "recaptcha_action_override" in auth_response_keys: - auth_response_dict["recaptchaActionOverride"] = auth_response[ - "recaptcha_action_override"] + auth_response_dict["recaptchaActionOverride"] = auth_response["recaptcha_action_override"] return auth_response_dict def _generate_response_payload( - auth_response_dict: dict[str, _typing.Any] | None + auth_response_dict: dict[str, _typing.Any] | None, ) -> dict[str, _typing.Any]: if not auth_response_dict: return {} formatted_auth_response = auth_response_dict.copy() - recaptcha_action_override = formatted_auth_response.pop( - "recaptchaActionOverride", None) + recaptcha_action_override = formatted_auth_response.pop("recaptchaActionOverride", None) result = {} update_mask = ",".join(formatted_auth_response.keys()) if len(update_mask) != 0: - result["userRecord"] = { - **formatted_auth_response, "updateMask": update_mask - } + result["userRecord"] = {**formatted_auth_response, "updateMask": update_mask} if recaptcha_action_override is not None: result["recaptchaActionOverride"] = recaptcha_action_override @@ -343,6 +341,7 @@ def before_operation_handler( request: _Request, ) -> _Response: from firebase_functions.identity_fn import BeforeCreateResponse, BeforeSignInResponse + try: if not _util.valid_on_call_request(request): _logging.error("Invalid request, unable to process.") @@ -356,8 +355,7 @@ def before_operation_handler( jwt_token = request.json["data"]["jwt"] decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token) event = _auth_blocking_event_from_token_data(event_type, decoded_token) - auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init( - func)(event) + auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(func)(event) if not auth_response: return _jsonify({}) auth_response_dict = _validate_auth_response(event_type, auth_response) diff --git a/src/firebase_functions/private/manifest.py b/src/firebase_functions/private/manifest.py index 7ebeac2e..7672a9f5 
100644 --- a/src/firebase_functions/private/manifest.py +++ b/src/firebase_functions/private/manifest.py @@ -19,11 +19,12 @@ import dataclasses as _dataclasses import typing as _typing +from enum import Enum as _Enum + import typing_extensions as _typing_extensions import firebase_functions.params as _params import firebase_functions.private.util as _util -from enum import Enum as _Enum class SecretEnvironmentVariable(_typing.TypedDict): @@ -55,52 +56,57 @@ class EventTrigger(_typing.TypedDict): Trigger definitions for endpoints that listen to CloudEvents emitted by other systems (or legacy Google events for GCF gen 1) """ - eventFilters: _typing_extensions.NotRequired[dict[str, str | - _params.Expression[str]]] - eventFilterPathPatterns: _typing_extensions.NotRequired[dict[ - str, str | _params.Expression[str]]] + + eventFilters: _typing_extensions.NotRequired[dict[str, str | _params.Expression[str]]] + eventFilterPathPatterns: _typing_extensions.NotRequired[ + dict[str, str | _params.Expression[str]] + ] channel: _typing_extensions.NotRequired[str] eventType: _typing_extensions.Required[str] - retry: _typing_extensions.Required[bool | _params.Expression[bool] | - _util.Sentinel] + retry: _typing_extensions.Required[bool | _params.Expression[bool] | _util.Sentinel] class RetryConfigBase(_typing.TypedDict): """ Retry configuration for a endpoint. """ - maxRetrySeconds: _typing_extensions.NotRequired[int | - _params.Expression[int] | - _util.Sentinel | None] - maxBackoffSeconds: _typing_extensions.NotRequired[int | - _params.Expression[int] | - _util.Sentinel | None] - maxDoublings: _typing_extensions.NotRequired[int | _params.Expression[int] | - _util.Sentinel | None] - minBackoffSeconds: _typing_extensions.NotRequired[int | - _params.Expression[int] | - _util.Sentinel | None] + + maxRetrySeconds: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] + maxBackoffSeconds: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] + maxDoublings: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] + minBackoffSeconds: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] class RetryConfigTasks(RetryConfigBase): """ Retry configuration for a task. """ - maxAttempts: _typing_extensions.NotRequired[int | _params.Expression[int] | - _util.Sentinel | None] + + maxAttempts: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] class RetryConfigScheduler(RetryConfigBase): """ Retry configuration for a schedule. 
""" - retryCount: _typing_extensions.NotRequired[int | _params.Expression[int] | - _util.Sentinel | None] + + retryCount: _typing_extensions.NotRequired[ + int | _params.Expression[int] | _util.Sentinel | None + ] class RateLimits(_typing.TypedDict): - maxConcurrentDispatches: int | _params.Expression[ - int] | _util.Sentinel | None + maxConcurrentDispatches: int | _params.Expression[int] | _util.Sentinel | None maxDispatchesPerSecond: int | _params.Expression[int] | _util.Sentinel | None @@ -110,6 +116,7 @@ class TaskQueueTrigger(_typing.TypedDict): Trigger definitions for RPCs servers using the HTTP protocol defined at https://firebase.google.com/docs/functions/callable-reference """ + retryConfig: RetryConfigTasks | None rateLimits: RateLimits | None @@ -143,8 +150,7 @@ class ManifestEndpoint: entryPoint: str | None = None region: list[str] | None = _dataclasses.field(default_factory=list[str]) platform: str | None = "gcfv2" - availableMemoryMb: int | _params.Expression[ - int] | _util.Sentinel | None = None + availableMemoryMb: int | _params.Expression[int] | _util.Sentinel | None = None maxInstances: int | _params.Expression[int] | _util.Sentinel | None = None minInstances: int | _params.Expression[int] | _util.Sentinel | None = None concurrency: int | _params.Expression[int] | _util.Sentinel | None = None @@ -154,9 +160,9 @@ class ManifestEndpoint: vpc: VpcSettings | None = None labels: dict[str, str] | None = None ingressSettings: str | None | _util.Sentinel = None - secretEnvironmentVariables: list[ - SecretEnvironmentVariable] | _util.Sentinel | None = _dataclasses.field( - default_factory=list[SecretEnvironmentVariable]) + secretEnvironmentVariables: list[SecretEnvironmentVariable] | _util.Sentinel | None = ( + _dataclasses.field(default_factory=list[SecretEnvironmentVariable]) + ) httpsTrigger: HttpsTrigger | None = None callableTrigger: CallableTrigger | None = None eventTrigger: EventTrigger | None = None @@ -174,27 +180,28 @@ class ManifestRequiredApi(_typing.TypedDict): class ManifestStack: endpoints: dict[str, ManifestEndpoint] specVersion: str = "v1alpha1" - params: list[_typing.Any] | None = _dataclasses.field( - default_factory=list[_typing.Any]) + params: list[_typing.Any] | None = _dataclasses.field(default_factory=list[_typing.Any]) requiredAPIs: list[ManifestRequiredApi] = _dataclasses.field( - default_factory=list[ManifestRequiredApi]) + default_factory=list[ManifestRequiredApi] + ) def _param_input_to_spec( - param_input: _params.TextInput | _params.ResourceInput | - _params.SelectInput | _params.MultiSelectInput + param_input: _params.TextInput + | _params.ResourceInput + | _params.SelectInput + | _params.MultiSelectInput, ) -> dict[str, _typing.Any]: if isinstance(param_input, _params.TextInput): return { "text": { - key: value for key, value in { - "example": - param_input.example, - "validationRegex": - param_input.validation_regex, - "validationErrorMessage": - param_input.validation_error_message, - }.items() if value is not None + key: value + for key, value in { + "example": param_input.example, + "validationRegex": param_input.validation_regex, + "validationErrorMessage": param_input.validation_error_message, + }.items() + if value is not None } } @@ -205,25 +212,28 @@ def _param_input_to_spec( }, } - if isinstance(param_input, (_params.MultiSelectInput, _params.SelectInput)): - key = "select" if isinstance(param_input, - _params.SelectInput) else "multiSelect" + if isinstance(param_input, _params.MultiSelectInput | _params.SelectInput): + key = "select" 
if isinstance(param_input, _params.SelectInput) else "multiSelect" return { key: { - "options": [{ - key: value for key, value in { - "value": option.value, - "label": option.label, - }.items() if value is not None - } for option in param_input.options], + "options": [ + { + key: value + for key, value in { + "value": option.value, + "label": option.label, + }.items() + if value is not None + } + for option in param_input.options + ], }, } return {} -def _param_to_spec( - param: _params.Param | _params.SecretParam) -> dict[str, _typing.Any]: +def _param_to_spec(param: _params.Param | _params.SecretParam) -> dict[str, _typing.Any]: spec_dict: dict[str, _typing.Any] = { "name": param.name, "label": param.label, @@ -232,8 +242,9 @@ def _param_to_spec( } if isinstance(param, _params.Param): - spec_dict["default"] = f"{param.default}" if isinstance( - param.default, _params.Expression) else param.default + spec_dict["default"] = ( + f"{param.default}" if isinstance(param.default, _params.Expression) else param.default + ) if param.input: spec_dict["input"] = _param_input_to_spec(param.input) @@ -270,7 +281,7 @@ def _object_to_spec(data) -> object: return data -def _dict_factory(data: list[_typing.Tuple[str, _typing.Any]]) -> dict: +def _dict_factory(data: list[tuple[str, _typing.Any]]) -> dict: out: dict = {} for key, value in data: if value is not None: diff --git a/src/firebase_functions/private/path_pattern.py b/src/firebase_functions/private/path_pattern.py index 6bd36c3b..1603dba9 100644 --- a/src/firebase_functions/private/path_pattern.py +++ b/src/firebase_functions/private/path_pattern.py @@ -13,12 +13,12 @@ # limitations under the License. """Path pattern matching utilities.""" -from enum import Enum import re +from enum import Enum def path_parts(path: str) -> list[str]: - if not path or path == "" or path == "/": + if not path or path in {"", "/"}: return [] return path.strip("/").split("/") @@ -30,7 +30,7 @@ def join_path(base: str, child: str) -> str: def trim_param(param: str) -> str: param_no_braces = param[1:-1] if "=" in param_no_braces: - return param_no_braces[:param_no_braces.index("=")] + return param_no_braces[: param_no_braces.index("=")] return param_no_braces @@ -50,6 +50,7 @@ class PathSegment: """ A segment of a path pattern. """ + name: SegmentName value: str trimmed: str @@ -89,6 +90,7 @@ class SingleCaptureSegment(PathSegment): """ A segment of a path pattern that captures a single segment. 
""" + name = SegmentName.SINGLE_CAPTURE def __init__(self, value): @@ -129,6 +131,7 @@ class PathPattern: Implements Eventarc's path pattern from the spec https://cloud.google.com/eventarc/docs/path-patterns """ + segments: list[PathSegment] def __init__(self, raw_path: str): @@ -157,15 +160,17 @@ def value(self) -> str: @property def has_wildcards(self) -> bool: - return any(segment.is_single_segment_wildcard or - segment.is_multi_segment_wildcard - for segment in self.segments) + return any( + segment.is_single_segment_wildcard or segment.is_multi_segment_wildcard + for segment in self.segments + ) @property def has_captures(self) -> bool: - return any(segment.name in (SegmentName.SINGLE_CAPTURE, - SegmentName.MULTI_CAPTURE) - for segment in self.segments) + return any( + segment.name in (SegmentName.SINGLE_CAPTURE, SegmentName.MULTI_CAPTURE) + for segment in self.segments + ) def extract_matches(self, path: str) -> dict[str, str]: matches: dict[str, str] = {} @@ -180,7 +185,6 @@ def extract_matches(self, path: str) -> dict[str, str]: if segment.name == SegmentName.SINGLE_CAPTURE: matches[segment.trimmed] = path_segments[path_ndx] elif segment.name == SegmentName.MULTI_CAPTURE: - matches[segment.trimmed] = "/".join( - path_segments[path_ndx:next_path_ndx]) + matches[segment.trimmed] = "/".join(path_segments[path_ndx:next_path_ndx]) path_ndx = next_path_ndx if segment.is_multi_segment_wildcard else path_ndx + 1 return matches diff --git a/src/firebase_functions/private/serving.py b/src/firebase_functions/private/serving.py index 1cfde330..7a2f21c4 100644 --- a/src/firebase_functions/private/serving.py +++ b/src/firebase_functions/private/serving.py @@ -14,21 +14,22 @@ """ Module used to serve Firebase functions locally and remotely. """ + # pylint: disable=protected-access -import os -import inspect import enum -import yaml import importlib +import inspect +import os import sys -from os import kill, getpid +from os import getpid, kill from signal import SIGTERM -from flask import Flask -from flask import Response +import yaml +from flask import Flask, Response +from firebase_functions import options as _options +from firebase_functions import params as _params from firebase_functions.private import manifest as _manifest -from firebase_functions import params as _params, options as _options from firebase_functions.private import util as _util @@ -52,7 +53,6 @@ def get_functions(): def to_spec(data: dict) -> dict: - def convert_value(obj): if isinstance(obj, enum.Enum): return obj.value @@ -62,13 +62,12 @@ def convert_value(obj): return list(map(convert_value, obj)) return obj - without_nones = dict( - (k, convert_value(v)) for k, v in data.items() if v is not None) + without_nones = {k: convert_value(v) for k, v in data.items() if v is not None} return without_nones def merge_required_apis( - required_apis: list[_manifest.ManifestRequiredApi] + required_apis: list[_manifest.ManifestRequiredApi], ) -> list[_manifest.ManifestRequiredApi]: api_to_reasons: dict[str, list[str]] = {} for api_reason in required_apis: @@ -125,7 +124,6 @@ def get_functions_yaml() -> Response: def quitquitquit(): - def quit_after_close(): kill(getpid(), SIGTERM) diff --git a/src/firebase_functions/private/token_verifier.py b/src/firebase_functions/private/token_verifier.py index a986ec47..096b1135 100644 --- a/src/firebase_functions/private/token_verifier.py +++ b/src/firebase_functions/private/token_verifier.py @@ -14,11 +14,20 @@ """ Module for internal token verification. 
""" -from firebase_admin import _token_gen, exceptions, _auth_utils, initialize_app, get_app, _apps, _DEFAULT_APP_NAME -from google.auth import jwt + import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +from firebase_admin import ( + _DEFAULT_APP_NAME, + _apps, + _auth_utils, + _token_gen, + exceptions, + get_app, + initialize_app, +) +from google.auth import jwt # pylint: disable=consider-using-f-string @@ -31,100 +40,91 @@ class _JWTVerifier: """Verifies Firebase JWTs (ID tokens or session cookies).""" def __init__(self, **kwargs): - self.project_id = kwargs.pop('project_id') - self.short_name = kwargs.pop('short_name') - self.operation = kwargs.pop('operation') - self.url = kwargs.pop('doc_url') - self.cert_url = kwargs.pop('cert_url') - self.issuer = kwargs.pop('issuer') - self.expected_audience = kwargs.pop('expected_audience') - if self.short_name[0].lower() in 'aeiou': - self.articled_short_name = 'an {0}'.format(self.short_name) + self.project_id = kwargs.pop("project_id") + self.short_name = kwargs.pop("short_name") + self.operation = kwargs.pop("operation") + self.url = kwargs.pop("doc_url") + self.cert_url = kwargs.pop("cert_url") + self.issuer = kwargs.pop("issuer") + self.expected_audience = kwargs.pop("expected_audience") + if self.short_name[0].lower() in "aeiou": + self.articled_short_name = f"an {self.short_name}" else: - self.articled_short_name = 'a {0}'.format(self.short_name) - self._invalid_token_error = kwargs.pop('invalid_token_error') - self._expired_token_error = kwargs.pop('expired_token_error') + self.articled_short_name = f"a {self.short_name}" + self._invalid_token_error = kwargs.pop("invalid_token_error") + self._expired_token_error = kwargs.pop("expired_token_error") def verify(self, token, request): """Verifies the signature and data for the provided JWT.""" - token = token.encode('utf-8') if isinstance(token, str) else token + token = token.encode("utf-8") if isinstance(token, str) else token if not isinstance(token, bytes) or not token: raise ValueError( - 'Illegal {0} provided: {1}. {0} must be a non-empty ' - 'string.'.format(self.short_name, token)) + f"Illegal {self.short_name} provided: {token}. {self.short_name} must be a non-empty string." + ) if not self.project_id: raise ValueError( - 'Failed to ascertain project ID from the credential or the environment. Project ' - 'ID is required to call {0}. Initialize the app with a credentials.Certificate ' - 'or set your Firebase project ID as an app option. Alternatively set the ' - 'GOOGLE_CLOUD_PROJECT environment variable.'.format( - self.operation)) + "Failed to ascertain project ID from the credential or the environment. Project " + f"ID is required to call {self.operation}. Initialize the app with a credentials.Certificate " + "or set your Firebase project ID as an app option. Alternatively set the " + "GOOGLE_CLOUD_PROJECT environment variable." 
+ ) header, payload = self._decode_unverified(token) - issuer = payload.get('iss') - audience = payload.get('aud') - subject = payload.get('sub') + issuer = payload.get("iss") + audience = payload.get("aud") + subject = payload.get("sub") expected_issuer = self.issuer + self.project_id project_id_match_msg = ( - 'Make sure the {0} comes from the same Firebase project as the service account used ' - 'to authenticate this SDK.'.format(self.short_name)) - verify_id_token_msg = ( - 'See {0} for details on how to retrieve {1}.'.format( - self.url, self.short_name)) + f"Make sure the {self.short_name} comes from the same Firebase project as the service account used " + "to authenticate this SDK." + ) + verify_id_token_msg = f"See {self.url} for details on how to retrieve {self.short_name}." emulated = _auth_utils.is_emulated() error_message = None if audience == _token_gen.FIREBASE_AUDIENCE: - error_message = ('{0} expects {1}, but was given a custom ' - 'token.'.format(self.operation, - self.articled_short_name)) - elif not emulated and not header.get('kid'): - if header.get('alg') == 'HS256' and payload.get( - 'v') == 0 and 'uid' in payload.get('d', {}): - error_message = ( - '{0} expects {1}, but was given a legacy custom ' - 'token.'.format(self.operation, self.articled_short_name)) + error_message = f"{self.operation} expects {self.articled_short_name}, but was given a custom token." + elif not emulated and not header.get("kid"): + if ( + header.get("alg") == "HS256" + and payload.get("v") == 0 + and "uid" in payload.get("d", {}) + ): + error_message = f"{self.operation} expects {self.articled_short_name}, but was given a legacy custom token." else: - error_message = 'Firebase {0} has no "kid" claim.'.format( - self.short_name) - elif not emulated and header.get('alg') != 'RS256': + error_message = f'Firebase {self.short_name} has no "kid" claim.' + elif not emulated and header.get("alg") != "RS256": error_message = ( - 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' - '"{1}". {2}'.format(self.short_name, header.get('alg'), - verify_id_token_msg)) + 'Firebase {} has incorrect algorithm. Expected "RS256" but got "{}". {}'.format( + self.short_name, header.get("alg"), verify_id_token_msg + ) + ) elif not emulated and self.expected_audience and self.expected_audience not in audience: error_message = ( - 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, - self.expected_audience, audience, - project_id_match_msg, - verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "aud" (audience) claim. Expected "{self.expected_audience}" but ' + f'got "{audience}". {project_id_match_msg} {verify_id_token_msg}' + ) elif not emulated and not self.expected_audience and audience != self.project_id: error_message = ( - 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, self.project_id, - audience, project_id_match_msg, - verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "aud" (audience) claim. Expected "{self.project_id}" but ' + f'got "{audience}". {project_id_match_msg} {verify_id_token_msg}' + ) elif issuer != expected_issuer: error_message = ( - 'Firebase {0} has incorrect "iss" (issuer) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, expected_issuer, - issuer, project_id_match_msg, - verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "iss" (issuer) claim. 
Expected "{expected_issuer}" but ' + f'got "{issuer}". {project_id_match_msg} {verify_id_token_msg}' + ) elif subject is None or not isinstance(subject, str): - error_message = ('Firebase {0} has no "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) - elif not subject: error_message = ( - 'Firebase {0} has an empty string "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has no "sub" (subject) claim. {verify_id_token_msg}' + ) + elif not subject: + error_message = f'Firebase {self.short_name} has an empty string "sub" (subject) claim. {verify_id_token_msg}' elif len(subject) > 128: - error_message = ( - 'Firebase {0} has a "sub" (subject) claim longer than 128 characters. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + error_message = f'Firebase {self.short_name} has a "sub" (subject) claim longer than 128 characters. {verify_id_token_msg}' if error_message: raise self._invalid_token_error(error_message) @@ -138,17 +138,17 @@ def verify(self, token, request): request=request, # If expected_audience is set then we have already verified # the audience above. - audience=(None - if self.expected_audience else self.project_id), - certs_url=self.cert_url) - verified_claims['uid'] = verified_claims['sub'] + audience=(None if self.expected_audience else self.project_id), + certs_url=self.cert_url, + ) + verified_claims["uid"] = verified_claims["sub"] return verified_claims except google.auth.exceptions.TransportError as error: - raise _token_gen.CertificateFetchError(str(error), cause=error) + raise _token_gen.CertificateFetchError(str(error), cause=error) # noqa: B904 except ValueError as error: - if 'Token expired' in str(error): - raise self._expired_token_error(str(error), cause=error) - raise self._invalid_token_error(str(error), cause=error) + if "Token expired" in str(error): + raise self._expired_token_error(str(error), cause=error) # noqa: B904 + raise self._invalid_token_error(str(error), cause=error) # noqa: B904 def _decode_unverified(self, token): try: @@ -156,17 +156,16 @@ def _decode_unverified(self, token): payload = jwt.decode(token, verify=False) return header, payload except ValueError as error: - raise self._invalid_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), cause=error) # noqa: B904 class InvalidAuthBlockingTokenError(exceptions.InvalidArgumentError): """The provided auth blocking token is not a token.""" - default_message = 'The provided auth blocking token is invalid' + default_message = "The provided auth blocking token is invalid" def __init__(self, message, cause=None, http_response=None): - exceptions.InvalidArgumentError.__init__(self, message, cause, - http_response) + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) class ExpiredAuthBlockingTokenError(InvalidAuthBlockingTokenError): @@ -183,15 +182,14 @@ def __init__(self, app): super().__init__(app) self.auth_blocking_token_verifier = _JWTVerifier( project_id=app.project_id, - short_name='Auth Blocking token', - operation='verify_auth_blocking_token()', - doc_url= - 'https://cloud.google.com/identity-platform/docs/blocking-functions', + short_name="Auth Blocking token", + operation="verify_auth_blocking_token()", + doc_url="https://cloud.google.com/identity-platform/docs/blocking-functions", cert_url=_token_gen.ID_TOKEN_CERT_URI, issuer=_token_gen.ID_TOKEN_ISSUER_PREFIX, invalid_token_error=InvalidAuthBlockingTokenError, 
expired_token_error=ExpiredAuthBlockingTokenError, - expected_audience='run.app', # v2 only + expected_audience="run.app", # v2 only ) def verify_auth_blocking_token(self, auth_blocking_token): @@ -205,5 +203,4 @@ def verify_auth_blocking_token(auth_blocking_token): """Verifies the provided auth blocking token.""" if _DEFAULT_APP_NAME not in _apps: initialize_app() - return AuthBlockingTokenVerifier( - get_app()).verify_auth_blocking_token(auth_blocking_token) + return AuthBlockingTokenVerifier(get_app()).verify_auth_blocking_token(auth_blocking_token) diff --git a/src/firebase_functions/private/util.py b/src/firebase_functions/private/util.py index b24ea26c..9df09035 100644 --- a/src/firebase_functions/private/util.py +++ b/src/firebase_functions/private/util.py @@ -16,23 +16,23 @@ """ import base64 -import os as _os -import json as _json -import re as _re -import typing as _typing import dataclasses as _dataclasses import datetime as _dt import enum as _enum +import json as _json +import os as _os +import re as _re +import typing as _typing + +from firebase_admin import app_check as _app_check +from firebase_admin import auth as _auth from flask import Request as _Request from functions_framework import logging as _logging -from firebase_admin import auth as _auth -from firebase_admin import app_check as _app_check P = _typing.ParamSpec("P") R = _typing.TypeVar("R") -JWT_REGEX = _re.compile( - r"^[a-zA-Z0-9\-_=]+?\.[a-zA-Z0-9\-_=]+?\.([a-zA-Z0-9\-_=]+)?$") +JWT_REGEX = _re.compile(r"^[a-zA-Z0-9\-_=]+?\.[a-zA-Z0-9\-_=]+?\.([a-zA-Z0-9\-_=]+)?$") class Sentinel: @@ -41,15 +41,16 @@ class Sentinel: def __init__(self, description): self.description = description + def __hash__(self): + return hash(self.description) + def __eq__(self, other): - return isinstance(other, - Sentinel) and self.description == other.description + return isinstance(other, Sentinel) and self.description == other.description def copy_func_kwargs( func_with_kwargs: _typing.Callable[P, _typing.Any], # pylint: disable=unused-argument ) -> _typing.Callable[[_typing.Callable[..., R]], _typing.Callable[P, R]]: - def return_func(func: _typing.Callable[..., R]) -> _typing.Callable[P, R]: return _typing.cast(_typing.Callable[P, R], func) @@ -60,7 +61,7 @@ def set_func_endpoint_attr( func: _typing.Callable[P, _typing.Any], endpoint: _typing.Any, ) -> _typing.Callable[P, _typing.Any]: - setattr(func, "__firebase_endpoint__", endpoint) + func.__firebase_endpoint__ = endpoint # type: ignore return func @@ -95,16 +96,16 @@ def deep_merge(dict1, dict2): def valid_on_call_request(request: _Request) -> bool: """Validate request""" - if (_on_call_valid_method(request) and - _on_call_valid_content_type(request) and - _on_call_valid_body(request)): + if ( + _on_call_valid_method(request) + and _on_call_valid_content_type(request) + and _on_call_valid_body(request) + ): return True return False -def convert_keys_to_camel_case( - data: dict[str, _typing.Any]) -> dict[str, _typing.Any]: - +def convert_keys_to_camel_case(data: dict[str, _typing.Any]) -> dict[str, _typing.Any]: def snake_to_camel(word: str) -> str: components = word.split("_") return components[0] + "".join(x.capitalize() for x in components[1:]) @@ -123,9 +124,7 @@ def _on_call_valid_body(request: _Request) -> bool: _logging.warning("Request body is missing data.", request.json) return False - extra_keys = { - key: request.json[key] for key in request.json.keys() if key != "data" - } + extra_keys = {key: request.json[key] for key in request.json.keys() if key != "data"} if 
len(extra_keys) != 0: _logging.warning( "Request body has extra fields: %s", @@ -212,11 +211,11 @@ def as_dict(self) -> dict: def _on_call_check_auth_token( - request: _Request + request: _Request, ) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]: """ - Validates the auth token in a callable request. - If verify_token is False, the token will be decoded without verification. + Validates the auth token in a callable request. + If verify_token is False, the token will be decoded without verification. """ authorization = request.headers.get("Authorization") if authorization is None: @@ -235,7 +234,7 @@ def _on_call_check_auth_token( def _on_call_check_app_token( - request: _Request + request: _Request, ) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]: """Validates the app token in a callable request.""" app_check = request.headers.get("X-Firebase-AppCheck") @@ -304,14 +303,13 @@ def on_call_check_tokens(request: _Request) -> _OnCallTokenVerification: if len(errs) == 0: _logging.info("Callable request verification passed: %s", log_payload) else: - _logging.warning(f"Callable request verification failed: ${errs}", - log_payload) + _logging.warning(f"Callable request verification failed: ${errs}", log_payload) return verifications @_dataclasses.dataclass(frozen=True) -class FirebaseConfig(): +class FirebaseConfig: """ A collection of configuration options needed to initialize a firebase App. @@ -337,11 +335,10 @@ def firebase_config() -> None | FirebaseConfig: # explicitly state that the user can set the env to a file: # https://firebase.google.com/docs/admin/setup#initialize-without-parameters try: - with open(config_file, "r", encoding="utf8") as json_file: + with open(config_file, encoding="utf8") as json_file: json_str = json_file.read() except Exception as err: - raise ValueError( - f"Unable to read file {config_file}. {err}") from err + raise ValueError(f"Unable to read file {config_file}. {err}") from err try: json_data: dict = _json.loads(json_str) except Exception as err: @@ -355,13 +352,11 @@ def nanoseconds_timestamp_conversion(time: str) -> _dt.datetime: """Converts a nanosecond timestamp and returns a datetime object of the current time in UTC""" # Separate the date and time part from the nanoseconds. - datetime_str, nanosecond_str = time.replace("Z", "").replace("z", - "").split(".") + datetime_str, nanosecond_str = time.replace("Z", "").replace("z", "").split(".") # Parse the date and time part of the string. event_time = _dt.datetime.strptime(datetime_str, "%Y-%m-%dT%H:%M:%S") # Add the microseconds and timezone. 
- event_time = event_time.replace(microsecond=int(nanosecond_str[:6]), - tzinfo=_dt.timezone.utc) + event_time = event_time.replace(microsecond=int(nanosecond_str[:6]), tzinfo=_dt.timezone.utc) return event_time @@ -398,8 +393,7 @@ def get_precision_timestamp(time: str) -> PrecisionTimestamp: return PrecisionTimestamp.SECONDS # Split the fraction from the timezone specifier ('Z' or 'z') - s_fraction, _ = s_fraction.split( - "Z") if "Z" in s_fraction else s_fraction.split("z") + s_fraction, _ = s_fraction.split("Z") if "Z" in s_fraction else s_fraction.split("z") # If the fraction is more than 6 digits long, it's a nanosecond timestamp if len(s_fraction) > 6: diff --git a/src/firebase_functions/pubsub_fn.py b/src/firebase_functions/pubsub_fn.py index ebac90ef..7297599d 100644 --- a/src/firebase_functions/pubsub_fn.py +++ b/src/firebase_functions/pubsub_fn.py @@ -14,17 +14,18 @@ """ Functions to handle events from Google Cloud Pub/Sub. """ + # pylint: disable=protected-access +import base64 as _base64 import dataclasses as _dataclasses import datetime as _dt import functools as _functools -import typing as _typing import json as _json -import base64 as _base64 +import typing as _typing + import cloudevents.http as _ce import firebase_functions.private.util as _util - from firebase_functions.core import CloudEvent, T, _with_init from firebase_functions.options import PubSubOptions @@ -68,9 +69,7 @@ def json(self) -> T | None: else: return None except Exception as error: - raise ValueError( - f"Unable to parse Pub/Sub message data as JSON: {error}" - ) from error + raise ValueError(f"Unable to parse Pub/Sub message data as JSON: {error}") from error @_dataclasses.dataclass(frozen=True) @@ -80,6 +79,7 @@ class MessagePublishedData(_typing.Generic[T]): 'T' Type representing `Message.data`'s JSON format. """ + message: Message[T] """ Google Cloud Pub/Sub message. @@ -109,8 +109,7 @@ def _message_handler( if "." not in event_dict["time"]: event_dict["time"] = event_dict["time"].replace("Z", ".000000Z") if "." not in message_dict["publish_time"]: - message_dict["publish_time"] = message_dict["publish_time"].replace( - "Z", ".000000Z") + message_dict["publish_time"] = message_dict["publish_time"].replace("Z", ".000000Z") time = _dt.datetime.strptime( event_dict["time"], @@ -185,7 +184,6 @@ def example(event: CloudEvent[MessagePublishedData[object]]) -> None: options = PubSubOptions(**kwargs) def on_message_published_inner_decorator(func: _C1): - @_functools.wraps(func) def on_message_published_wrapped(raw: _ce.CloudEvent): return _message_handler(func, raw) diff --git a/src/firebase_functions/remote_config_fn.py b/src/firebase_functions/remote_config_fn.py index c48436d5..bb48aa9b 100644 --- a/src/firebase_functions/remote_config_fn.py +++ b/src/firebase_functions/remote_config_fn.py @@ -15,15 +15,16 @@ """ Cloud functions to handle Remote Config events. 
""" + import dataclasses as _dataclasses -import functools as _functools import datetime as _dt +import enum as _enum +import functools as _functools import typing as _typing + import cloudevents.http as _ce -import enum as _enum import firebase_functions.private.util as _util - from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -163,8 +164,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None: config_data = ConfigUpdateData( version_number=event_data["versionNumber"], - update_time=_dt.datetime.strptime(event_data["updateTime"], - "%Y-%m-%dT%H:%M:%S.%f%z"), + update_time=_dt.datetime.strptime(event_data["updateTime"], "%Y-%m-%dT%H:%M:%S.%f%z"), update_user=ConfigUser( name=event_data["updateUser"]["name"], email=event_data["updateUser"]["email"], @@ -216,7 +216,6 @@ def example(event: CloudEvent[ConfigUpdateData]) -> None: options = EventHandlerOptions(**kwargs) def on_config_updated_inner_decorator(func: _C1): - @_functools.wraps(func) def on_config_updated_wrapped(raw: _ce.CloudEvent): return _config_handler(func, raw) @@ -226,7 +225,7 @@ def on_config_updated_wrapped(raw: _ce.CloudEvent): options._endpoint( func_name=func.__name__, event_filters={}, - event_type="google.firebase.remoteconfig.remoteConfig.v1.updated" + event_type="google.firebase.remoteconfig.remoteConfig.v1.updated", ), ) return on_config_updated_wrapped diff --git a/src/firebase_functions/scheduler_fn.py b/src/firebase_functions/scheduler_fn.py index c5a92c99..1979f67e 100644 --- a/src/firebase_functions/scheduler_fn.py +++ b/src/firebase_functions/scheduler_fn.py @@ -13,24 +13,29 @@ # limitations under the License. """Cloud functions to handle Schedule triggers.""" -import typing as _typing import dataclasses as _dataclasses import datetime as _dt import functools as _functools +import typing as _typing -import firebase_functions.options as _options -import firebase_functions.private.util as _util -from functions_framework import logging as _logging from flask import ( Request as _Request, +) +from flask import ( Response as _Response, +) +from flask import ( make_response as _make_response, ) +from functions_framework import logging as _logging +import firebase_functions.options as _options +import firebase_functions.private.util as _util from firebase_functions.core import _with_init -# Export for user convenience. 
-# pylint: disable=unused-import -from firebase_functions.options import Timezone + +# Re-export Timezone from options module so users can import it directly from scheduler_fn +# This provides a more convenient API: from firebase_functions.scheduler_fn import Timezone +from firebase_functions.options import Timezone # noqa: F401 @_dataclasses.dataclass(frozen=True) @@ -91,12 +96,10 @@ def example(event: scheduler_fn.ScheduledEvent) -> None: options = _options.ScheduleOptions(**kwargs) def on_schedule_decorator(func: _C): - @_functools.wraps(func) def on_schedule_wrapped(request: _Request) -> _Response: schedule_time: _dt.datetime - schedule_time_str = request.headers.get( - "X-CloudScheduler-ScheduleTime") + schedule_time_str = request.headers.get("X-CloudScheduler-ScheduleTime") if schedule_time_str is None: schedule_time = _dt.datetime.utcnow() else: diff --git a/src/firebase_functions/storage_fn.py b/src/firebase_functions/storage_fn.py index 342d2573..29895678 100644 --- a/src/firebase_functions/storage_fn.py +++ b/src/firebase_functions/storage_fn.py @@ -14,11 +14,13 @@ """ Functions to handle events from Google Cloud Storage. """ + # pylint: disable=protected-access import dataclasses as _dataclasses import datetime as _dt import functools as _functools import typing as _typing + import cloudevents.http as _ce import firebase_functions.private.util as _util @@ -235,10 +237,11 @@ def _message_handler( updated=data.get("updated"), # Custom type fields: customer_encryption=CustomerEncryption( - encryption_algorithm=data["customerEncryption"] - ["encryptionAlgorithm"], + encryption_algorithm=data["customerEncryption"]["encryptionAlgorithm"], key_sha256=data["customerEncryption"]["keySha256"], - ) if data.get("customerEncryption") is not None else None, + ) + if data.get("customerEncryption") is not None + else None, ) event: CloudEvent[StorageObjectData] = CloudEvent( @@ -246,8 +249,7 @@ def _message_handler( id=event_attributes["id"], source=event_attributes["source"], specversion=event_attributes["specversion"], - subject=event_attributes["subject"] - if "subject" in event_attributes else None, + subject=event_attributes["subject"] if "subject" in event_attributes else None, time=_dt.datetime.strptime( event_attributes["time"], "%Y-%m-%dT%H:%M:%S.%f%z", @@ -284,15 +286,13 @@ def example(event: CloudEvent[StorageObjectData]) -> None: options = StorageOptions(**kwargs) def on_object_archived_inner_decorator(func: _C1): - @_functools.wraps(func) def on_object_archived_wrapped(raw: _ce.CloudEvent): return _message_handler(func, raw) _util.set_func_endpoint_attr( on_object_archived_wrapped, - options._endpoint(func_name=func.__name__, - event_type=_event_type_archived), + options._endpoint(func_name=func.__name__, event_type=_event_type_archived), ) return on_object_archived_wrapped @@ -326,15 +326,13 @@ def example(event: CloudEvent[StorageObjectData]) -> None: options = StorageOptions(**kwargs) def on_object_finalized_inner_decorator(func: _C1): - @_functools.wraps(func) def on_object_finalized_wrapped(raw: _ce.CloudEvent): return _message_handler(func, raw) _util.set_func_endpoint_attr( on_object_finalized_wrapped, - options._endpoint(func_name=func.__name__, - event_type=_event_type_finalized), + options._endpoint(func_name=func.__name__, event_type=_event_type_finalized), ) return on_object_finalized_wrapped @@ -369,15 +367,13 @@ def example(event: CloudEvent[StorageObjectData]) -> None: options = StorageOptions(**kwargs) def on_object_deleted_inner_decorator(func: _C1): - 
@_functools.wraps(func) def on_object_deleted_wrapped(raw: _ce.CloudEvent): return _message_handler(func, raw) _util.set_func_endpoint_attr( on_object_deleted_wrapped, - options._endpoint(func_name=func.__name__, - event_type=_event_type_deleted), + options._endpoint(func_name=func.__name__, event_type=_event_type_deleted), ) return on_object_deleted_wrapped @@ -408,15 +404,13 @@ def example(event: CloudEvent[StorageObjectData]) -> None: options = StorageOptions(**kwargs) def on_object_metadata_updated_inner_decorator(func: _C1): - @_functools.wraps(func) def on_object_metadata_updated_wrapped(raw: _ce.CloudEvent): return _message_handler(func, raw) _util.set_func_endpoint_attr( on_object_metadata_updated_wrapped, - options._endpoint(func_name=func.__name__, - event_type=_event_type_metadata_updated), + options._endpoint(func_name=func.__name__, event_type=_event_type_metadata_updated), ) return on_object_metadata_updated_wrapped diff --git a/src/firebase_functions/tasks_fn.py b/src/firebase_functions/tasks_fn.py index 7b8c6750..1fea65c0 100644 --- a/src/firebase_functions/tasks_fn.py +++ b/src/firebase_functions/tasks_fn.py @@ -14,19 +14,20 @@ """Functions to handle Tasks enqueued with Google Cloud Tasks.""" # pylint: disable=protected-access -import typing as _typing -import functools as _functools import dataclasses as _dataclasses +import functools as _functools import json as _json +import typing as _typing -from flask import Request, Response, make_response as _make_response, jsonify as _jsonify +from flask import Request, Response +from flask import jsonify as _jsonify +from flask import make_response as _make_response +from functions_framework import logging as _logging import firebase_functions.core as _core import firebase_functions.options as _options import firebase_functions.private.util as _util -from firebase_functions.https_fn import CallableRequest, HttpsError, FunctionsErrorCode - -from functions_framework import logging as _logging +from firebase_functions.https_fn import CallableRequest, FunctionsErrorCode, HttpsError _C = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] _C1 = _typing.Callable[[Request], Response] @@ -51,8 +52,7 @@ def _on_call_handler(func: _C2, request: Request) -> Response: # pushes with FCM. In that case, the FCM APIs will validate the token. context = _dataclasses.replace( context, - instance_id_token=request.headers.get( - "Firebase-Instance-ID-Token"), + instance_id_token=request.headers.get("Firebase-Instance-ID-Token"), ) result = _core._with_init(func)(context) return _jsonify(result=result) @@ -91,7 +91,6 @@ def example(request: tasks.CallableRequest) -> Any: options = _options.TaskQueueOptions(**kwargs) def on_task_dispatched_decorator(func: _C): - @_functools.wraps(func) def on_task_dispatched_wrapped(request: Request) -> Response: return _on_call_handler(func, request) diff --git a/src/firebase_functions/test_lab_fn.py b/src/firebase_functions/test_lab_fn.py index 7aede95e..9ca8cdb5 100644 --- a/src/firebase_functions/test_lab_fn.py +++ b/src/firebase_functions/test_lab_fn.py @@ -15,15 +15,16 @@ """ Cloud functions to handle Test Lab events. 
""" + import dataclasses as _dataclasses -import functools as _functools import datetime as _dt +import enum as _enum +import functools as _functools import typing as _typing + import cloudevents.http as _ce -import enum as _enum import firebase_functions.private.util as _util - from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -213,18 +214,15 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None: event_dict = {**event_data, **event_attributes} test_lab_data = TestMatrixCompletedData( - create_time=_dt.datetime.strptime(event_data["createTime"], - "%Y-%m-%dT%H:%M:%S.%f%z"), + create_time=_dt.datetime.strptime(event_data["createTime"], "%Y-%m-%dT%H:%M:%S.%f%z"), state=TestState(event_data["state"]), invalid_matrix_details=event_data.get("invalidMatrixDetails"), outcome_summary=OutcomeSummary(event_data["outcomeSummary"]), result_storage=ResultStorage( - tool_results_history=event_data["resultStorage"] - ["toolResultsHistory"], + tool_results_history=event_data["resultStorage"]["toolResultsHistory"], results_uri=event_data["resultStorage"]["resultsUri"], gcs_path=event_data["resultStorage"]["gcsPath"], - tool_results_execution=event_data["resultStorage"].get( - "toolResultsExecution"), + tool_results_execution=event_data["resultStorage"].get("toolResultsExecution"), ), client_info=ClientInfo( client=event_data["clientInfo"]["client"], @@ -273,7 +271,6 @@ def example(event: CloudEvent[ConfigUpdateData]) -> None: options = EventHandlerOptions(**kwargs) def on_test_matrix_completed_inner_decorator(func: _C1): - @_functools.wraps(func) def on_test_matrix_completed_wrapped(raw: _ce.CloudEvent): return _event_handler(func, raw) @@ -283,7 +280,8 @@ def on_test_matrix_completed_wrapped(raw: _ce.CloudEvent): options._endpoint( func_name=func.__name__, event_filters={}, - event_type="google.firebase.testlab.testMatrix.v1.completed"), + event_type="google.firebase.testlab.testMatrix.v1.completed", + ), ) return on_test_matrix_completed_wrapped diff --git a/tests/test_db.py b/tests/test_db.py index 4e8b487d..59af3378 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,7 +4,9 @@ import unittest from unittest import mock + from cloudevents.http import CloudEvent + from firebase_functions import core, db_fn @@ -24,19 +26,21 @@ def init(): func = mock.Mock(__name__="example_func") decorated_func = db_fn.on_value_created(reference="path")(func) - event = CloudEvent(attributes={ - "specversion": "1.0", - "id": "id", - "source": "source", - "subject": "subject", - "type": "type", - "time": "2024-04-10T12:00:00.000Z", - "instance": "instance", - "ref": "ref", - "firebasedatabasehost": "firebasedatabasehost", - "location": "location", - }, - data={"delta": "delta"}) + event = CloudEvent( + attributes={ + "specversion": "1.0", + "id": "id", + "source": "source", + "subject": "subject", + "type": "type", + "time": "2024-04-10T12:00:00.000Z", + "instance": "instance", + "ref": "ref", + "firebasedatabasehost": "firebasedatabasehost", + "location": "location", + }, + data={"delta": "delta"}, + ) decorated_func(event) diff --git a/tests/test_eventarc_fn.py b/tests/test_eventarc_fn.py index 730812a2..eb0d76a5 100644 --- a/tests/test_eventarc_fn.py +++ b/tests/test_eventarc_fn.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Eventarc trigger function tests.""" + import unittest from unittest.mock import Mock @@ -38,7 +39,7 @@ def test_on_custom_event_published_decorator(self): event_type="firebase.extensions.storage-resize-images.v1.complete", )(func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.eventTrigger) self.assertEqual( diff --git a/tests/test_firestore_fn.py b/tests/test_firestore_fn.py index fa4ed156..0e07b102 100644 --- a/tests/test_firestore_fn.py +++ b/tests/test_firestore_fn.py @@ -9,7 +9,7 @@ mocked_modules = { "google.cloud.firestore": MagicMock(), "google.cloud.firestore_v1": MagicMock(), - "firebase_admin": MagicMock() + "firebase_admin": MagicMock(), } @@ -21,47 +21,41 @@ class TestFirestore(TestCase): def test_firestore_endpoint_handler_calls_function_with_correct_args(self): with patch.dict("sys.modules", mocked_modules): from cloudevents.http import CloudEvent - from firebase_functions.firestore_fn import _event_type_created_with_auth_context as event_type, \ - _firestore_endpoint_handler as firestore_endpoint_handler, AuthEvent + + from firebase_functions.firestore_fn import ( + AuthEvent, + ) + from firebase_functions.firestore_fn import ( + _event_type_created_with_auth_context as event_type, + ) + from firebase_functions.firestore_fn import ( + _firestore_endpoint_handler as firestore_endpoint_handler, + ) from firebase_functions.private import path_pattern func = Mock(__name__="example_func") document_pattern = path_pattern.PathPattern("foo/{bar}") attributes = { - "specversion": - "1.0", - "type": - event_type, - "source": - "https://example.com/testevent", - "time": - "2023-03-11T13:25:37.403Z", - "subject": - "test_subject", - "datacontenttype": - "application/json", - "location": - "projects/project-id/databases/(default)/documents/foo/{bar}", - "project": - "project-id", - "namespace": - "(default)", - "document": - "foo/{bar}", - "database": - "projects/project-id/databases/(default)", - "authtype": - "unauthenticated", - "authid": - "foo" + "specversion": "1.0", + "type": event_type, + "source": "https://example.com/testevent", + "time": "2023-03-11T13:25:37.403Z", + "subject": "test_subject", + "datacontenttype": "application/json", + "location": "projects/project-id/databases/(default)/documents/foo/{bar}", + "project": "project-id", + "namespace": "(default)", + "document": "foo/{bar}", + "database": "projects/project-id/databases/(default)", + "authtype": "unauthenticated", + "authid": "foo", } raw_event = CloudEvent(attributes=attributes, data=json.dumps({})) - firestore_endpoint_handler(func=func, - event_type=event_type, - document_pattern=document_pattern, - raw=raw_event) + firestore_endpoint_handler( + func=func, event_type=event_type, document_pattern=document_pattern, raw=raw_event + ) func.assert_called_once() @@ -73,9 +67,10 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self): def test_calls_init_function(self): with patch.dict("sys.modules", mocked_modules): - from firebase_functions import firestore_fn, core from cloudevents.http import CloudEvent + from firebase_functions import core, firestore_fn + func = Mock(__name__="example_func") hello = None @@ -86,37 +81,23 @@ def init(): hello = "world" attributes = { - "specversion": - "1.0", + "specversion": "1.0", # pylint: disable=protected-access - "type": - firestore_fn._event_type_created, - "source": - "https://example.com/testevent", - "time": - 
"2023-03-11T13:25:37.403Z", - "subject": - "test_subject", - "datacontenttype": - "application/json", - "location": - "projects/project-id/databases/(default)/documents/foo/{bar}", - "project": - "project-id", - "namespace": - "(default)", - "document": - "foo/{bar}", - "database": - "projects/project-id/databases/(default)", - "authtype": - "unauthenticated", - "authid": - "foo" + "type": firestore_fn._event_type_created, + "source": "https://example.com/testevent", + "time": "2023-03-11T13:25:37.403Z", + "subject": "test_subject", + "datacontenttype": "application/json", + "location": "projects/project-id/databases/(default)/documents/foo/{bar}", + "project": "project-id", + "namespace": "(default)", + "document": "foo/{bar}", + "database": "projects/project-id/databases/(default)", + "authtype": "unauthenticated", + "authid": "foo", } raw_event = CloudEvent(attributes=attributes, data=json.dumps({})) - decorated_func = firestore_fn.on_document_created( - document="/foo/{bar}")(func) + decorated_func = firestore_fn.on_document_created(document="/foo/{bar}")(func) decorated_func(raw_event) diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py index e128b392..1748b367 100644 --- a/tests/test_https_fn.py +++ b/tests/test_https_fn.py @@ -4,6 +4,7 @@ import unittest from unittest.mock import Mock + from flask import Flask, Request from werkzeug.test import EnvironBuilder @@ -31,9 +32,7 @@ def init(): environ = EnvironBuilder( method="POST", json={ - "data": { - "test": "value" - }, + "data": {"test": "value"}, }, ).get_environ() request = Request(environ) @@ -59,9 +58,7 @@ def init(): environ = EnvironBuilder( method="POST", json={ - "data": { - "test": "value" - }, + "data": {"test": "value"}, }, ).get_environ() request = Request(environ) diff --git a/tests/test_identity_fn.py b/tests/test_identity_fn.py index b71414bc..b3d43fbf 100644 --- a/tests/test_identity_fn.py +++ b/tests/test_identity_fn.py @@ -3,7 +3,8 @@ """ import unittest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import MagicMock, Mock, patch + from flask import Flask, Request from werkzeug.test import EnvironBuilder @@ -12,18 +13,13 @@ token_verifier_mock = MagicMock() token_verifier_mock.verify_auth_blocking_token = Mock( return_value={ - "user_record": { - "uid": "uid", - "metadata": { - "creation_time": 0 - }, - "provider_data": [] - }, + "user_record": {"uid": "uid", "metadata": {"creation_time": 0}, "provider_data": []}, "event_id": "event_id", "ip_address": "ip_address", "user_agent": "user_agent", - "iat": 0 - }) + "iat": 0, + } +) mocked_modules = { "firebase_functions.private.token_verifier": token_verifier_mock, } @@ -45,16 +41,13 @@ def init(): with patch.dict("sys.modules", mocked_modules): app = Flask(__name__) - func = Mock(__name__="example_func", - return_value=identity_fn.BeforeSignInResponse()) + func = Mock(__name__="example_func", return_value=identity_fn.BeforeSignInResponse()) with app.test_request_context("/"): environ = EnvironBuilder( method="POST", json={ - "data": { - "jwt": "jwt" - }, + "data": {"jwt": "jwt"}, }, ).get_environ() request = Request(environ) diff --git a/tests/test_init.py b/tests/test_init.py index 07ee9240..41458357 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -3,6 +3,7 @@ """ import unittest + from firebase_functions import core @@ -12,7 +13,6 @@ class TestInit(unittest.TestCase): """ def test_init_is_initialized(self): - @core.init def fn(): pass diff --git a/tests/test_logger.py b/tests/test_logger.py index 9f995217..8f6aaee6 100644 
--- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -3,8 +3,10 @@ Logger module tests. """ -import pytest import json + +import pytest + from firebase_functions import logger @@ -13,8 +15,7 @@ class TestLogger: Tests for the logger module. """ - def test_format_should_be_valid_json(self, - capsys: pytest.CaptureFixture[str]): + def test_format_should_be_valid_json(self, capsys: pytest.CaptureFixture[str]): logger.log(foo="bar") raw_log_output = capsys.readouterr().out try: @@ -34,8 +35,7 @@ def test_severity_should_be_debug(self, capsys: pytest.CaptureFixture[str]): log_output = json.loads(raw_log_output) assert log_output["severity"] == "DEBUG" - def test_severity_should_be_notice(self, - capsys: pytest.CaptureFixture[str]): + def test_severity_should_be_notice(self, capsys: pytest.CaptureFixture[str]): logger.log(foo="bar") raw_log_output = capsys.readouterr().out log_output = json.loads(raw_log_output) @@ -47,8 +47,7 @@ def test_severity_should_be_info(self, capsys: pytest.CaptureFixture[str]): log_output = json.loads(raw_log_output) assert log_output["severity"] == "INFO" - def test_severity_should_be_warning(self, - capsys: pytest.CaptureFixture[str]): + def test_severity_should_be_warning(self, capsys: pytest.CaptureFixture[str]): logger.warn(foo="bar") raw_log_output = capsys.readouterr().out log_output = json.loads(raw_log_output) @@ -66,23 +65,20 @@ def test_log_should_have_message(self, capsys: pytest.CaptureFixture[str]): log_output = json.loads(raw_log_output) assert "message" in log_output - def test_log_should_have_other_keys(self, - capsys: pytest.CaptureFixture[str]): + def test_log_should_have_other_keys(self, capsys: pytest.CaptureFixture[str]): logger.log(foo="bar") raw_log_output = capsys.readouterr().out log_output = json.loads(raw_log_output) assert "foo" in log_output - def test_message_should_be_space_separated( - self, capsys: pytest.CaptureFixture[str]): + def test_message_should_be_space_separated(self, capsys: pytest.CaptureFixture[str]): logger.log("bar", "qux") expected_message = "bar qux" raw_log_output = capsys.readouterr().out log_output = json.loads(raw_log_output) assert log_output["message"] == expected_message - def test_remove_circular_references(self, - capsys: pytest.CaptureFixture[str]): + def test_remove_circular_references(self, capsys: pytest.CaptureFixture[str]): # Create an object with a circular reference. circ = {"b": "foo"} circ["circ"] = circ @@ -99,15 +95,11 @@ def test_remove_circular_references(self, expected = { "severity": "ERROR", "message": "testing circular", - "circ": { - "b": "foo", - "circ": "[CIRCULAR]" - }, + "circ": {"b": "foo", "circ": "[CIRCULAR]"}, } assert log_output == expected - def test_remove_circular_references_in_arrays( - self, capsys: pytest.CaptureFixture[str]): + def test_remove_circular_references_in_arrays(self, capsys: pytest.CaptureFixture[str]): # Create an object with a circular reference inside an array. circ = {"b": "foo"} circ["circ"] = [circ] @@ -124,15 +116,11 @@ def test_remove_circular_references_in_arrays( expected = { "severity": "ERROR", "message": "testing circular", - "circ": { - "b": "foo", - "circ": ["[CIRCULAR]"] - }, + "circ": {"b": "foo", "circ": ["[CIRCULAR]"]}, } assert log_output == expected - def test_no_false_circular_for_duplicates( - self, capsys: pytest.CaptureFixture[str]): + def test_no_false_circular_for_duplicates(self, capsys: pytest.CaptureFixture[str]): # Ensure that duplicate objects (used in multiple keys) are not marked as circular. 
obj = {"a": "foo"} entry = { @@ -148,28 +136,17 @@ def test_no_false_circular_for_duplicates( expected = { "severity": "ERROR", "message": "testing circular", - "a": { - "a": "foo" - }, - "b": { - "a": "foo" - }, + "a": {"a": "foo"}, + "b": {"a": "foo"}, } assert log_output == expected - def test_no_false_circular_in_array_duplicates( - self, capsys: pytest.CaptureFixture[str]): + def test_no_false_circular_in_array_duplicates(self, capsys: pytest.CaptureFixture[str]): # Ensure that duplicate objects in arrays are not falsely detected as circular. obj = {"a": "foo"} arr = [ - { - "a": obj, - "b": obj - }, - { - "a": obj, - "b": obj - }, + {"a": obj, "b": obj}, + {"a": obj, "b": obj}, ] entry = { "severity": "ERROR", @@ -182,45 +159,15 @@ def test_no_false_circular_in_array_duplicates( log_output = json.loads(raw_log_output) expected = { - "severity": - "ERROR", - "message": - "testing circular", + "severity": "ERROR", + "message": "testing circular", "a": [ - { - "a": { - "a": "foo" - }, - "b": { - "a": "foo" - } - }, - { - "a": { - "a": "foo" - }, - "b": { - "a": "foo" - } - }, + {"a": {"a": "foo"}, "b": {"a": "foo"}}, + {"a": {"a": "foo"}, "b": {"a": "foo"}}, ], "b": [ - { - "a": { - "a": "foo" - }, - "b": { - "a": "foo" - } - }, - { - "a": { - "a": "foo" - }, - "b": { - "a": "foo" - } - }, + {"a": {"a": "foo"}, "b": {"a": "foo"}}, + {"a": {"a": "foo"}, "b": {"a": "foo"}}, ], } assert log_output == expected diff --git a/tests/test_manifest.py b/tests/test_manifest.py index f948b66f..681d90f5 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -13,8 +13,8 @@ # limitations under the License. """Manifest unit tests.""" -import firebase_functions.private.manifest as _manifest import firebase_functions.params as _params +import firebase_functions.private.manifest as _manifest full_endpoint = _manifest.ManifestEndpoint( platform="gcfv2", @@ -33,9 +33,7 @@ labels={ "hello": "world", }, - secretEnvironmentVariables=[{ - "key": "MY_SECRET" - }], + secretEnvironmentVariables=[{"key": "MY_SECRET"}], ) full_endpoint_dict = { @@ -55,9 +53,7 @@ "labels": { "hello": "world", }, - "secretEnvironmentVariables": [{ - "key": "MY_SECRET" - }], + "secretEnvironmentVariables": [{"key": "MY_SECRET"}], } full_stack = _manifest.ManifestStack( @@ -70,43 +66,29 @@ _params.StringParam("STRING_TEST"), _params.ListParam("LIST_TEST", default=["1", "2", "3"]), ], - requiredAPIs=[{ - "api": "test_api", - "reason": "testing" - }]) + requiredAPIs=[{"api": "test_api", "reason": "testing"}], +) full_stack_dict = { "specVersion": "v1alpha1", - "endpoints": { - "test": full_endpoint_dict - }, - "params": [{ - "name": "BOOL_TEST", - "type": "boolean", - "default": False, - }, { - "name": "INT_TEST", - "type": "int", - "description": "int_description" - }, { - "name": "FLOAT_TEST", - "type": "float", - "immutable": True, - }, { - "name": "SECRET_TEST", - "type": "secret" - }, { - "name": "STRING_TEST", - "type": "string" - }, { - "default": ["1", "2", "3"], - "name": "LIST_TEST", - "type": "list" - }], - "requiredAPIs": [{ - "api": "test_api", - "reason": "testing" - }] + "endpoints": {"test": full_endpoint_dict}, + "params": [ + { + "name": "BOOL_TEST", + "type": "boolean", + "default": False, + }, + {"name": "INT_TEST", "type": "int", "description": "int_description"}, + { + "name": "FLOAT_TEST", + "type": "float", + "immutable": True, + }, + {"name": "SECRET_TEST", "type": "secret"}, + {"name": "STRING_TEST", "type": "string"}, + {"default": ["1", "2", "3"], "name": "LIST_TEST", "type": "list"}, + ], + 
"requiredAPIs": [{"api": "test_api", "reason": "testing"}], } @@ -116,8 +98,9 @@ class TestManifestStack: def test_stack_to_dict(self): """Generic check that all ManifestStack values convert to dict.""" stack_dict = _manifest.manifest_to_spec_dict(full_stack) - assert (stack_dict == full_stack_dict - ), "Generated manifest spec dict does not match expected dict." + assert stack_dict == full_stack_dict, ( + "Generated manifest spec dict does not match expected dict." + ) class TestManifestEndpoint: @@ -127,38 +110,37 @@ def test_endpoint_to_dict(self): """Generic check that all ManifestEndpoint values convert to dict.""" # pylint: disable=protected-access endpoint_dict = _manifest._dataclass_to_spec(full_endpoint) - assert (endpoint_dict == full_endpoint_dict - ), "Generated endpoint spec dict does not match expected dict." + assert endpoint_dict == full_endpoint_dict, ( + "Generated endpoint spec dict does not match expected dict." + ) def test_endpoint_expressions(self): """Check Expression values convert to CEL strings.""" max_param = _params.IntParam("MAX") expressions_test = _manifest.ManifestEndpoint( - availableMemoryMb=_params.TernaryExpression( - _params.BoolParam("LARGE_BOOL"), 1024, 256), - minInstances=_params.StringParam("LARGE_STR").equals("yes").then( - 6, 1), + availableMemoryMb=_params.TernaryExpression(_params.BoolParam("LARGE_BOOL"), 1024, 256), + minInstances=_params.StringParam("LARGE_STR").equals("yes").then(6, 1), maxInstances=max_param.compare(">", 6).then(6, max_param), timeoutSeconds=_params.IntParam("WORLD"), concurrency=_params.IntParam("BAR"), - vpc={"connector": _params.SecretParam("SECRET")}) + vpc={"connector": _params.SecretParam("SECRET")}, + ) expressions_expected_dict = { "platform": "gcfv2", "region": [], "secretEnvironmentVariables": [], "availableMemoryMb": "{{ params.LARGE_BOOL ? 1024 : 256 }}", - "minInstances": "{{ params.LARGE_STR == \"yes\" ? 6 : 1 }}", + "minInstances": '{{ params.LARGE_STR == "yes" ? 6 : 1 }}', "maxInstances": "{{ params.MAX > 6 ? 6 : params.MAX }}", "timeoutSeconds": "{{ params.WORLD }}", "concurrency": "{{ params.BAR }}", - "vpc": { - "connector": "{{ params.SECRET }}" - } + "vpc": {"connector": "{{ params.SECRET }}"}, } # pylint: disable=protected-access expressions_actual_dict = _manifest._dataclass_to_spec(expressions_test) - assert (expressions_actual_dict == expressions_expected_dict - ), "Generated endpoint spec dict does not match expected dict." + assert expressions_actual_dict == expressions_expected_dict, ( + "Generated endpoint spec dict does not match expected dict." + ) def test_endpoint_nones(self): """Check all None values are removed.""" @@ -175,5 +157,6 @@ def test_endpoint_nones(self): } # pylint: disable=protected-access expressions_actual_dict = _manifest._dataclass_to_spec(expressions_test) - assert (expressions_actual_dict == expressions_expected_dict - ), "Generated endpoint spec dict does not match expected dict." + assert expressions_actual_dict == expressions_expected_dict, ( + "Generated endpoint spec dict does not match expected dict." + ) diff --git a/tests/test_options.py b/tests/test_options.py index baaf64af..2f4ceb87 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -14,10 +14,12 @@ """ Options unit tests. 
""" -from firebase_functions import options, https_fn -from firebase_functions import params -from firebase_functions.private.serving import functions_as_yaml, merge_required_apis + from pytest import raises + +from firebase_functions import https_fn, options, params +from firebase_functions.private.serving import functions_as_yaml, merge_required_apis + # pylint: disable=protected-access @@ -47,19 +49,18 @@ def test_global_options_merged_with_provider_options(): options.set_global_options(max_instances=66) pubsub_options = options.PubSubOptions(topic="foo") # pylint: disable=unexpected-keyword-arg pubsub_options_dict = pubsub_options._asdict_with_global_options() - assert (pubsub_options_dict["topic"] == "foo" - ), "'topic' property missing from dict" + assert pubsub_options_dict["topic"] == "foo", "'topic' property missing from dict" assert "options" not in pubsub_options_dict, "'options' key should not exist in dict" - assert (pubsub_options_dict["max_instances"] == 66 - ), "provider option did not update using the global option" + assert pubsub_options_dict["max_instances"] == 66, ( + "provider option did not update using the global option" + ) def test_https_options_removes_cors(): """ Testing _HttpsOptions strips out the 'cors' property when converted to a dict. """ - https_options = options.HttpsOptions(cors=options.CorsOptions( - cors_origins="*")) + https_options = options.HttpsOptions(cors=options.CorsOptions(cors_origins="*")) assert https_options.cors.cors_origins == "*", "cors options were not set" https_options_dict = https_options._asdict_with_global_options() assert "cors" not in https_options_dict, "'cors' key should not exist in dict" @@ -71,25 +72,25 @@ def test_options_asdict_uses_cel_representation(): CEL values for manifest representation. """ int_param = params.IntParam("MIN") - https_options_dict = options.HttpsOptions( - min_instances=int_param)._asdict_with_global_options() - assert https_options_dict[ - "min_instances"] == f"{int_param}", "param was not converted to CEL string" + https_options_dict = options.HttpsOptions(min_instances=int_param)._asdict_with_global_options() + assert https_options_dict["min_instances"] == f"{int_param}", ( + "param was not converted to CEL string" + ) def test_options_preserve_external_changes(): """ Testing if setting a global option internally change the values. """ - assert (options._GLOBAL_OPTIONS.preserve_external_changes - is None), "option should not already be set" + assert options._GLOBAL_OPTIONS.preserve_external_changes is None, ( + "option should not already be set" + ) options.set_global_options( preserve_external_changes=False, min_instances=5, ) options_asdict = options._GLOBAL_OPTIONS._asdict_with_global_options() - assert (options_asdict["max_instances"] - is options.RESET_VALUE), "option should be RESET_VALUE" + assert options_asdict["max_instances"] is options.RESET_VALUE, "option should be RESET_VALUE" assert options_asdict["min_instances"] == 5, "option should be set" firebase_functions = { @@ -132,33 +133,15 @@ def test_merge_apis_no_duplicate_apis(): APIs without modification when there is no duplication. 
""" required_apis = [ - { - "api": "API1", - "reason": "Reason 1" - }, - { - "api": "API2", - "reason": "Reason 2" - }, - { - "api": "API3", - "reason": "Reason 3" - }, + {"api": "API1", "reason": "Reason 1"}, + {"api": "API2", "reason": "Reason 2"}, + {"api": "API3", "reason": "Reason 3"}, ] expected_output = [ - { - "api": "API1", - "reason": "Reason 1" - }, - { - "api": "API2", - "reason": "Reason 2" - }, - { - "api": "API3", - "reason": "Reason 3" - }, + {"api": "API1", "reason": "Reason 1"}, + {"api": "API2", "reason": "Reason 2"}, + {"api": "API3", "reason": "Reason 3"}, ] merged_apis = merge_required_apis(required_apis) @@ -176,48 +159,32 @@ def test_merge_apis_duplicate_apis(): APIs and combines the reasons associated with them. """ required_apis = [ - { - "api": "API1", - "reason": "Reason 1" - }, - { - "api": "API2", - "reason": "Reason 2" - }, - { - "api": "API1", - "reason": "Reason 3" - }, - { - "api": "API2", - "reason": "Reason 4" - }, + {"api": "API1", "reason": "Reason 1"}, + {"api": "API2", "reason": "Reason 2"}, + {"api": "API1", "reason": "Reason 3"}, + {"api": "API2", "reason": "Reason 4"}, ] expected_output = [ - { - "api": "API1", - "reason": "Reason 1 Reason 3" - }, - { - "api": "API2", - "reason": "Reason 2 Reason 4" - }, + {"api": "API1", "reason": "Reason 1 Reason 3"}, + {"api": "API2", "reason": "Reason 2 Reason 4"}, ] merged_apis = merge_required_apis(required_apis) - assert len(merged_apis) == len( - expected_output - ), f"Expected a list of length {len(expected_output)}, but got {len(merged_apis)}" + assert len(merged_apis) == len(expected_output), ( + f"Expected a list of length {len(expected_output)}, but got {len(merged_apis)}" + ) for expected_item in expected_output: - assert (expected_item in merged_apis - ), f"Expected item {expected_item} missing from the merged list" + assert expected_item in merged_apis, ( + f"Expected item {expected_item} missing from the merged list" + ) for actual_item in merged_apis: - assert (actual_item in expected_output - ), f"Unexpected item {actual_item} found in the merged list" + assert actual_item in expected_output, ( + f"Unexpected item {actual_item} found in the merged list" + ) def test_invoker_with_one_element_doesnt_throw(): @@ -226,8 +193,6 @@ def test_invoker_with_one_element_doesnt_throw(): def test_invoker_with_no_element_throws(): with raises( - AssertionError, - match= - "HttpsOptions: Invalid option for invoker - must be a non-empty list." + AssertionError, match="HttpsOptions: Invalid option for invoker - must be a non-empty list." ): options.HttpsOptions(invoker=[])._endpoint(func_name="test") diff --git a/tests/test_params.py b/tests/test_params.py index 92e81038..02a5ed14 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Param unit tests.""" + from os import environ import pytest + from firebase_functions import params @@ -24,32 +26,33 @@ class TestBoolParams: def test_bool_param_value_true_or_false(self): """Testing if bool params correctly returns a true or false value.""" bool_param = params.BoolParam("BOOL_VALUE_TEST1") - for value_true, value_false in zip(["true"], - ["false", "anything", "else"]): + for value_true, value_false in zip(["true"], ["false", "anything", "else"], strict=False): environ["BOOL_VALUE_TEST1"] = value_true - assert (bool_param.value is True), "Failure, params returned False" + assert bool_param.value is True, "Failure, params returned False" environ["BOOL_VALUE_TEST1"] = value_false - assert (bool_param.value is False), "Failure, params returned True" + assert bool_param.value is False, "Failure, params returned True" def test_bool_param_empty_default(self): """Testing if bool params defaults to False if no value and no default.""" - assert (params.BoolParam("BOOL_DEFAULT_TEST").value - is False), "Failure, params returned True" + assert params.BoolParam("BOOL_DEFAULT_TEST").value is False, "Failure, params returned True" def test_bool_param_default(self): """Testing if bool params defaults to provided default value.""" - assert (params.BoolParam("BOOL_DEFAULT_TEST_FALSE", default=False).value - is False), "Failure, params returned True" - assert (params.BoolParam("BOOL_DEFAULT_TEST_TRUE", default=True).value - is True), "Failure, params returned False" + assert params.BoolParam("BOOL_DEFAULT_TEST_FALSE", default=False).value is False, ( + "Failure, params returned True" + ) + assert params.BoolParam("BOOL_DEFAULT_TEST_TRUE", default=True).value is True, ( + "Failure, params returned False" + ) def test_bool_param_equality(self): """Test bool equality.""" - assert (params.BoolParam("BOOL_TEST1", - default=False).equals(False).value - is True), "Failure, equality check returned False" - assert (params.BoolParam("BOOL_TEST2", default=True).equals(False).value - is False), "Failure, equality check returned False" + assert params.BoolParam("BOOL_TEST1", default=False).equals(False).value is True, ( + "Failure, equality check returned False" + ) + assert params.BoolParam("BOOL_TEST2", default=True).equals(False).value is False, ( + "Failure, equality check returned False" + ) class TestFloatParams: @@ -58,28 +61,33 @@ class TestFloatParams: def test_float_param_value(self): """Testing if float params correctly returns a value.""" environ["FLOAT_VALUE_TEST"] = "123.456" - assert params._FloatParam("FLOAT_VALUE_TEST",).value == 123.456, \ - "Failure, params value != 123.456" + assert ( + params._FloatParam( + "FLOAT_VALUE_TEST", + ).value + == 123.456 + ), "Failure, params value != 123.456" def test_float_param_empty_default(self): """Testing if float params defaults to empty float if no value and no default.""" - assert params._FloatParam("FLOAT_DEFAULT_TEST1").value == float(), \ + assert params._FloatParam("FLOAT_DEFAULT_TEST1").value == 0.0, ( "Failure, params value is not float" + ) def test_float_param_default(self): """Testing if float param defaults to provided default value.""" - assert params._FloatParam("FLOAT_DEFAULT_TEST2", \ - default=float(456.789)).value == 456.789, \ + assert params._FloatParam("FLOAT_DEFAULT_TEST2", default=456.789).value == 456.789, ( "Failure, params default value != 456.789" + ) def test_float_param_equality(self): """Test float equality.""" - assert (params._FloatParam("FLOAT_TEST1", \ - default=123.456).equals(123.456).value \ - is True), 
"Failure, equality check returned False" - assert (params._FloatParam("FLOAT_TEST2", \ - default=456.789).equals(123.456).value \ - is False), "Failure, equality check returned False" + assert params._FloatParam("FLOAT_TEST1", default=123.456).equals(123.456).value is True, ( + "Failure, equality check returned False" + ) + assert params._FloatParam("FLOAT_TEST2", default=456.789).equals(123.456).value is False, ( + "Failure, equality check returned False" + ) class TestIntParams: @@ -88,25 +96,26 @@ class TestIntParams: def test_int_param_value(self): """Testing if int param correctly returns a value.""" environ["INT_VALUE_TEST"] = "123" - assert params.IntParam( - "INT_VALUE_TEST").value == 123, "Failure, params value != 123" + assert params.IntParam("INT_VALUE_TEST").value == 123, "Failure, params value != 123" def test_int_param_empty_default(self): """Testing if int param defaults to empty int if no value and no default.""" - assert params.IntParam("INT_DEFAULT_TEST1").value == int( - ), "Failure, params value is not int" + assert params.IntParam("INT_DEFAULT_TEST1").value == 0, "Failure, params value is not int" def test_int_param_default(self): """Testing if int param defaults to provided default value.""" - assert params.IntParam("INT_DEFAULT_TEST2", default=456).value == 456, \ + assert params.IntParam("INT_DEFAULT_TEST2", default=456).value == 456, ( "Failure, params default value != 456" + ) def test_int_param_equality(self): """Test int equality.""" - assert (params.IntParam("INT_TEST1", default=123).equals(123).value - is True), "Failure, equality check returned False" - assert (params.IntParam("INT_TEST2", default=456).equals(123).value - is False), "Failure, equality check returned False" + assert params.IntParam("INT_TEST1", default=123).equals(123).value is True, ( + "Failure, equality check returned False" + ) + assert params.IntParam("INT_TEST2", default=456).equals(123).value is False, ( + "Failure, equality check returned False" + ) class TestStringParams: @@ -115,8 +124,9 @@ class TestStringParams: def test_string_param_value(self): """Testing if string param correctly returns a value.""" environ["STRING_VALUE_TEST"] = "STRING_TEST" - assert params.StringParam("STRING_VALUE_TEST").value == "STRING_TEST", \ + assert params.StringParam("STRING_VALUE_TEST").value == "STRING_TEST", ( 'Failure, params value != "STRING_TEST"' + ) def test_param_name_upper_snake_case(self): """Testing if param names are validated to be upper snake case.""" @@ -126,24 +136,25 @@ def test_param_name_upper_snake_case(self): def test_string_param_empty_default(self): """Testing if string param defaults to empty string if no value and no default.""" - assert params.StringParam("STRING_DEFAULT_TEST1").value == str(), \ + assert params.StringParam("STRING_DEFAULT_TEST1").value == "", ( "Failure, params value is not a string" + ) def test_string_param_default(self): """Testing if string param defaults to provided default value.""" - assert (params.StringParam("STRING_DEFAULT_TEST2", - default="string_override_default").value - == "string_override_default"), \ - 'Failure, params default value != "string_override_default"' + assert ( + params.StringParam("STRING_DEFAULT_TEST2", default="string_override_default").value + == "string_override_default" + ), 'Failure, params default value != "string_override_default"' def test_string_param_equality(self): """Test string equality.""" - assert (params.StringParam("STRING_TEST1", - default="123").equals("123").value - is True), "Failure, equality check 
returned False" - assert (params.StringParam("STRING_TEST2", - default="456").equals("123").value - is False), "Failure, equality check returned False" + assert params.StringParam("STRING_TEST1", default="123").equals("123").value is True, ( + "Failure, equality check returned False" + ) + assert params.StringParam("STRING_TEST2", default="456").equals("123").value is False, ( + "Failure, equality check returned False" + ) class TestListParams: @@ -152,34 +163,37 @@ class TestListParams: def test_list_param_value(self): """Testing if list param correctly returns list values.""" environ["LIST_VALUE_TEST1"] = "item1,item2" - assert params.ListParam("LIST_VALUE_TEST1").value == ["item1","item2"], \ + assert params.ListParam("LIST_VALUE_TEST1").value == ["item1", "item2"], ( 'Failure, params value != ["item1","item2"]' + ) def test_list_param_filter_empty_strings(self): """Testing if list param correctly returns list values wth empty strings excluded.""" environ["LIST_VALUE_TEST2"] = ",,item1,item2,,,item3," - assert params.ListParam("LIST_VALUE_TEST2").value == ["item1","item2", "item3"], \ + assert params.ListParam("LIST_VALUE_TEST2").value == ["item1", "item2", "item3"], ( 'Failure, params value != ["item1","item2", "item3"]' + ) def test_list_param_empty_default(self): """Testing if list param defaults to an empty list if no value and no default.""" - assert params.ListParam("LIST_DEFAULT_TEST1").value == [], \ + assert params.ListParam("LIST_DEFAULT_TEST1").value == [], ( "Failure, params value is not an empty list" + ) def test_list_param_default(self): """Testing if list param defaults to the provided default value.""" - assert (params.ListParam("LIST_DEFAULT_TEST2", default=["1", "2"]).value - == ["1", "2"]), \ + assert params.ListParam("LIST_DEFAULT_TEST2", default=["1", "2"]).value == ["1", "2"], ( 'Failure, params default value != ["1", "2"]' + ) def test_list_param_equality(self): """Test list equality.""" - assert (params.ListParam("LIST_TEST1", - default=["123"]).equals(["123"]).value - is True), "Failure, equality check returned False" - assert (params.ListParam("LIST_TEST2", - default=["456"]).equals(["123"]).value - is False), "Failure, equality check returned False" + assert params.ListParam("LIST_TEST1", default=["123"]).equals(["123"]).value is True, ( + "Failure, equality check returned False" + ) + assert params.ListParam("LIST_TEST2", default=["456"]).equals(["123"]).value is False, ( + "Failure, equality check returned False" + ) class TestParamsManifest: @@ -192,18 +206,18 @@ def test_params_stored(self): """Testing if params are internally stored.""" environ["TEST_STORING"] = "TEST_STORING_VALUE" param = params.StringParam("TEST_STORING") - assert param.value == "TEST_STORING_VALUE", \ - 'Failure, params value != "TEST_STORING_VALUE"' + assert param.value == "TEST_STORING_VALUE", 'Failure, params value != "TEST_STORING_VALUE"' # pylint: disable=protected-access - assert params._params["TEST_STORING"] == param, \ - "Failure, param was not stored" + assert params._params["TEST_STORING"] == param, "Failure, param was not stored" def test_default_params_not_stored(self): """Testing if default params are skipped from being stored.""" environ["GCLOUD_PROJECT"] = "python-testing-project" - assert params.PROJECT_ID.value == "python-testing-project", \ + assert params.PROJECT_ID.value == "python-testing-project", ( 'Failure, params value != "python-testing-project"' + ) # pylint: disable=protected-access - assert params._params.get("GCLOUD_PROJECT") is None, \ + assert 
params._params.get("GCLOUD_PROJECT") is None, ( "Failure, default param was stored when it should not have been" + ) diff --git a/tests/test_path_pattern.py b/tests/test_path_pattern.py index 3513e342..780a2a33 100644 --- a/tests/test_path_pattern.py +++ b/tests/test_path_pattern.py @@ -14,7 +14,8 @@ """Path Pattern unit tests.""" from unittest import TestCase -from firebase_functions.private.path_pattern import path_parts, PathPattern, trim_param + +from firebase_functions.private.path_pattern import PathPattern, path_parts, trim_param class TestPathUtilities(TestCase): @@ -77,8 +78,7 @@ def test_extract_matches(self): # parse multi segment with params after pp = PathPattern("something/**/else/{a}/hello/{b}/world") self.assertEqual( - pp.extract_matches( - "something/is/a/thing/else/nothing/hello/user/world"), + pp.extract_matches("something/is/a/thing/else/nothing/hello/user/world"), { "a": "nothing", "b": "user", @@ -88,8 +88,7 @@ def test_extract_matches(self): # parse multi-capture segment with params after pp = PathPattern("something/{path=**}/else/{a}/hello/{b}/world") self.assertEqual( - pp.extract_matches( - "something/is/a/thing/else/nothing/hello/user/world"), + pp.extract_matches("something/is/a/thing/else/nothing/hello/user/world"), { "path": "is/a/thing", "a": "nothing", @@ -100,8 +99,7 @@ def test_extract_matches(self): # parse multi segment with params before pp = PathPattern("{a}/something/{b}/**/end") self.assertEqual( - pp.extract_matches( - "match_a/something/match_b/thing/else/nothing/hello/user/end"), + pp.extract_matches("match_a/something/match_b/thing/else/nothing/hello/user/end"), { "a": "match_a", "b": "match_b", @@ -111,8 +109,7 @@ def test_extract_matches(self): # parse multi-capture segment with params before pp = PathPattern("{a}/something/{b}/{path=**}/end") self.assertEqual( - pp.extract_matches( - "match_a/something/match_b/thing/else/nothing/hello/user/end"), + pp.extract_matches("match_a/something/match_b/thing/else/nothing/hello/user/end"), { "a": "match_a", "b": "match_b", @@ -123,8 +120,7 @@ def test_extract_matches(self): # parse multi segment with params before and after pp = PathPattern("{a}/something/**/{b}/end") self.assertEqual( - pp.extract_matches( - "match_a/something/thing/else/nothing/hello/user/match_b/end"), + pp.extract_matches("match_a/something/thing/else/nothing/hello/user/match_b/end"), { "a": "match_a", "b": "match_b", @@ -134,8 +130,7 @@ def test_extract_matches(self): # parse multi-capture segment with params before and after pp = PathPattern("{a}/something/{path=**}/{b}/end") self.assertEqual( - pp.extract_matches( - "match_a/something/thing/else/nothing/hello/user/match_b/end"), + pp.extract_matches("match_a/something/thing/else/nothing/hello/user/match_b/end"), { "a": "match_a", "b": "match_b", diff --git a/tests/test_pubsub_fn.py b/tests/test_pubsub_fn.py index 74ae72ee..d0bf81b7 100644 --- a/tests/test_pubsub_fn.py +++ b/tests/test_pubsub_fn.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
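One change above in tests/test_params.py is not purely cosmetic: `zip(...)` now passes `strict=False` explicitly. The keyword exists since Python 3.10, and spelling it out is what a lint rule such as flake8-bugbear's B905 (available through ruff) asks for; `strict=False` preserves the old behaviour of silently stopping at the shortest iterable:

```python
# strict=False keeps the historical behaviour: stop at the shortest iterable.
pairs = list(zip(["true"], ["false", "anything", "else"], strict=False))
assert pairs == [("true", "false")]

# strict=True would raise instead when the lengths differ.
try:
    list(zip(["true"], ["false", "anything"], strict=True))
except ValueError:
    pass  # mismatched lengths are reported as an error
```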
"""PubSub function tests.""" -import unittest + import datetime as _dt +import unittest from unittest.mock import MagicMock + from cloudevents.http import CloudEvent as _CloudEvent from firebase_functions import core from firebase_functions.pubsub_fn import ( + CloudEvent, Message, MessagePublishedData, - on_message_published, _message_handler, - CloudEvent, + on_message_published, ) @@ -40,12 +42,11 @@ def test_on_message_published_decorator(self): func = MagicMock() func.__name__ = "testfn" decorated_func = on_message_published(topic="hello-world")(func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.eventTrigger) self.assertIsNotNone(endpoint.eventTrigger["eventType"]) - self.assertEqual("hello-world", - endpoint.eventTrigger["eventFilters"]["topic"]) + self.assertEqual("hello-world", endpoint.eventTrigger["eventFilters"]["topic"]) def test_message_handler(self): """ @@ -64,9 +65,7 @@ def test_message_handler(self): }, data={ "message": { - "attributes": { - "key": "value" - }, + "attributes": {"key": "value"}, # {"test": "value"} "data": "eyJ0ZXN0IjogInZhbHVlIn0=", "message_id": "message-id-123", @@ -88,11 +87,10 @@ def test_message_handler(self): _dt.datetime.strptime( "2023-03-11T13:25:37.403Z", "%Y-%m-%dT%H:%M:%S.%f%z", - )) - self.assertDictEqual(event_arg.data.message.attributes, - {"key": "value"}) - self.assertEqual(event_arg.data.message.data, - "eyJ0ZXN0IjogInZhbHVlIn0=") + ), + ) + self.assertDictEqual(event_arg.data.message.attributes, {"key": "value"}) + self.assertEqual(event_arg.data.message.data, "eyJ0ZXN0IjogInZhbHVlIn0=") self.assertIsNone(event_arg.data.message.ordering_key) self.assertEqual(event_arg.data.subscription, "my-subscription") @@ -115,9 +113,7 @@ def init(): }, data={ "message": { - "attributes": { - "key": "value" - }, + "attributes": {"key": "value"}, "data": "eyJ0ZXN0IjogInZhbHVlIn0=", "message_id": "message-id-123", "publish_time": "2023-03-11T13:25:37.403Z", @@ -142,9 +138,7 @@ def test_datetime_without_mircroseconds_doesnt_throw(self): }, data={ "message": { - "attributes": { - "key": "value" - }, + "attributes": {"key": "value"}, "data": "eyJ0ZXN0IjogInZhbHVlIn0=", "message_id": "message-id-123", "publish_time": time, @@ -156,5 +150,4 @@ def test_datetime_without_mircroseconds_doesnt_throw(self): _message_handler(lambda _: None, raw_event) # pylint: disable=broad-except except Exception: - self.fail( - "Datetime without microseconds should not throw an exception") + self.fail("Datetime without microseconds should not throw an exception") diff --git a/tests/test_remote_config_fn.py b/tests/test_remote_config_fn.py index 2854a0a8..4ff7ebb5 100644 --- a/tests/test_remote_config_fn.py +++ b/tests/test_remote_config_fn.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Remote Config function tests.""" + import unittest from unittest.mock import MagicMock + from cloudevents.http import CloudEvent as _CloudEvent from firebase_functions.remote_config_fn import ( CloudEvent, - ConfigUser, ConfigUpdateData, ConfigUpdateOrigin, ConfigUpdateType, - on_config_updated, + ConfigUser, _config_handler, + on_config_updated, ) @@ -40,7 +42,7 @@ def test_on_config_updated_decorator(self): func = MagicMock() func.__name__ = "testfn" decorated_func = on_config_updated()(func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.eventTrigger) self.assertIsNotNone(endpoint.eventTrigger["eventType"]) @@ -66,13 +68,14 @@ def test_config_handler(self): "updateUser": { "name": "John Doe", "email": "johndoe@example.com", - "imageUrl": "https://example.com/image.jpg" + "imageUrl": "https://example.com/image.jpg", }, "description": "Test update", "updateOrigin": "CONSOLE", "updateType": "INCREMENTAL_UPDATE", - "rollbackSource": 41 - }) + "rollbackSource": 41, + }, + ) _config_handler(func, raw_event) @@ -83,8 +86,6 @@ def test_config_handler(self): self.assertIsInstance(event_arg.data, ConfigUpdateData) self.assertIsInstance(event_arg.data.update_user, ConfigUser) self.assertEqual(event_arg.data.version_number, 42) - self.assertEqual(event_arg.data.update_origin, - ConfigUpdateOrigin.CONSOLE) - self.assertEqual(event_arg.data.update_type, - ConfigUpdateType.INCREMENTAL_UPDATE) + self.assertEqual(event_arg.data.update_origin, ConfigUpdateOrigin.CONSOLE) + self.assertEqual(event_arg.data.update_type, ConfigUpdateType.INCREMENTAL_UPDATE) self.assertEqual(event_arg.data.rollback_source, 41) diff --git a/tests/test_scheduler_fn.py b/tests/test_scheduler_fn.py index 56853c65..b3a8c82f 100644 --- a/tests/test_scheduler_fn.py +++ b/tests/test_scheduler_fn.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Scheduler function tests.""" + import unittest -from unittest.mock import Mock from datetime import datetime -from flask import Request, Flask +from unittest.mock import Mock + +from flask import Flask, Request from werkzeug.test import EnvironBuilder -from firebase_functions import scheduler_fn, core + +from firebase_functions import core, scheduler_fn class TestScheduler(unittest.TestCase): @@ -35,9 +38,9 @@ def test_on_schedule_decorator(self): tz = "America/Los_Angeles" example_func = Mock(__name__="example_func") decorated_func = scheduler_fn.on_schedule( - schedule="* * * * *", - timezone=scheduler_fn.Timezone(tz))(example_func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + schedule="* * * * *", timezone=scheduler_fn.Timezone(tz) + )(example_func) + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.scheduleTrigger) @@ -55,12 +58,12 @@ def test_on_schedule_call(self): environ = EnvironBuilder( headers={ "X-CloudScheduler-JobName": "example-job", - "X-CloudScheduler-ScheduleTime": "2023-04-13T12:00:00-07:00" - }).get_environ() + "X-CloudScheduler-ScheduleTime": "2023-04-13T12:00:00-07:00", + } + ).get_environ() mock_request = Request(environ) example_func = Mock(__name__="example_func") - decorated_func = scheduler_fn.on_schedule( - schedule="* * * * *")(example_func) + decorated_func = scheduler_fn.on_schedule(schedule="* * * * *")(example_func) response = decorated_func(mock_request) self.assertEqual(response.status_code, 200) @@ -75,7 +78,8 @@ def test_on_schedule_call(self): 0, tzinfo=scheduler_fn.Timezone("America/Los_Angeles"), ), - )) + ) + ) def test_on_schedule_call_with_no_headers(self): """ @@ -88,8 +92,7 @@ def test_on_schedule_call_with_no_headers(self): environ = EnvironBuilder().get_environ() mock_request = Request(environ) example_func = Mock(__name__="example_func") - decorated_func = scheduler_fn.on_schedule( - schedule="* * * * *")(example_func) + decorated_func = scheduler_fn.on_schedule(schedule="* * * * *")(example_func) response = decorated_func(mock_request) self.assertEqual(response.status_code, 200) @@ -107,13 +110,12 @@ def test_on_schedule_call_with_exception(self): environ = EnvironBuilder( headers={ "X-CloudScheduler-JobName": "example-job", - "X-CloudScheduler-ScheduleTime": "2023-04-13T12:00:00-07:00" - }).get_environ() + "X-CloudScheduler-ScheduleTime": "2023-04-13T12:00:00-07:00", + } + ).get_environ() mock_request = Request(environ) - example_func = Mock(__name__="example_func", - side_effect=Exception("Test exception")) - decorated_func = scheduler_fn.on_schedule( - schedule="* * * * *")(example_func) + example_func = Mock(__name__="example_func", side_effect=Exception("Test exception")) + decorated_func = scheduler_fn.on_schedule(schedule="* * * * *")(example_func) response = decorated_func(mock_request) self.assertEqual(response.status_code, 500) @@ -131,8 +133,7 @@ def init(): environ = EnvironBuilder().get_environ() mock_request = Request(environ) example_func = Mock(__name__="example_func") - decorated_func = scheduler_fn.on_schedule( - schedule="* * * * *")(example_func) + decorated_func = scheduler_fn.on_schedule(schedule="* * * * *")(example_func) decorated_func(mock_request) self.assertEqual("world", hello) diff --git a/tests/test_storage_fn.py b/tests/test_storage_fn.py index ec55ca86..299c9e41 100644 --- a/tests/test_storage_fn.py +++ b/tests/test_storage_fn.py @@ -5,9 +5,10 @@ import unittest from unittest.mock import Mock -from firebase_functions import core, 
storage_fn from cloudevents.http import CloudEvent +from firebase_functions import core, storage_fn + class TestStorage(unittest.TestCase): """ @@ -23,19 +24,18 @@ def init(): hello = "world" func = Mock(__name__="example_func") - event = CloudEvent(attributes={ - "source": "source", - "type": "type" - }, - data={ - "bucket": "bucket", - "generation": "generation", - "id": "id", - "metageneration": "metageneration", - "name": "name", - "size": "size", - "storageClass": "storageClass", - }) + event = CloudEvent( + attributes={"source": "source", "type": "type"}, + data={ + "bucket": "bucket", + "generation": "generation", + "id": "id", + "metageneration": "metageneration", + "name": "name", + "size": "size", + "storageClass": "storageClass", + }, + ) decorated_func = storage_fn.on_object_archived(bucket="bucket")(func) decorated_func(event) diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py index b16ede31..8c76678f 100644 --- a/tests/test_tasks_fn.py +++ b/tests/test_tasks_fn.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Task Queue function tests.""" -import unittest +import unittest from unittest.mock import MagicMock, Mock + from flask import Flask, Request from werkzeug.test import EnvironBuilder from firebase_functions import core -from firebase_functions.tasks_fn import on_task_dispatched, CallableRequest +from firebase_functions.tasks_fn import CallableRequest, on_task_dispatched class TestTasks(unittest.TestCase): @@ -36,7 +37,7 @@ def test_on_task_dispatched_decorator(self): func = MagicMock() func.__name__ = "testfn" decorated_func = on_task_dispatched()(func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.taskQueueTrigger) @@ -58,9 +59,7 @@ def example(request: CallableRequest[object]) -> str: environ = EnvironBuilder( method="POST", json={ - "data": { - "test": "value" - }, + "data": {"test": "value"}, }, ).get_environ() request = Request(environ) @@ -87,9 +86,7 @@ def init(): environ = EnvironBuilder( method="POST", json={ - "data": { - "test": "value" - }, + "data": {"test": "value"}, }, ).get_environ() request = Request(environ) diff --git a/tests/test_test_lab_fn.py b/tests/test_test_lab_fn.py index afa98364..7d7f6427 100644 --- a/tests/test_test_lab_fn.py +++ b/tests/test_test_lab_fn.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
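The scheduler tests above build a mock request carrying `X-CloudScheduler-JobName` and `X-CloudScheduler-ScheduleTime` headers and expect the wrapped handler to receive a timezone-aware `schedule_time`. The offset-bearing header value shown in the diff parses directly with the stdlib:

```python
from datetime import datetime

schedule_time = datetime.fromisoformat("2023-04-13T12:00:00-07:00")
assert schedule_time.utcoffset() is not None  # aware datetime with a -07:00 offset
print(schedule_time.year, schedule_time.month, schedule_time.hour)  # 2023 4 12
```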
"""Test Lab function tests.""" + import unittest from unittest.mock import MagicMock, Mock + from cloudevents.http import CloudEvent as _CloudEvent from firebase_functions import core from firebase_functions.test_lab_fn import ( + ClientInfo, CloudEvent, - TestMatrixCompletedData, - TestState, OutcomeSummary, ResultStorage, - ClientInfo, - on_test_matrix_completed, + TestMatrixCompletedData, + TestState, _event_handler, + on_test_matrix_completed, ) @@ -42,7 +44,7 @@ def test_on_test_matrix_completed_decorator(self): func = MagicMock() func.__name__ = "testfn" decorated_func = on_test_matrix_completed()(func) - endpoint = getattr(decorated_func, "__firebase_endpoint__") + endpoint = decorated_func.__firebase_endpoint__ self.assertIsNotNone(endpoint) self.assertIsNotNone(endpoint.eventTrigger) self.assertIsNotNone(endpoint.eventTrigger["eventType"]) @@ -68,20 +70,17 @@ def test_event_handler(self): "invalidMatrixDetails": "Some details", "outcomeSummary": "SUCCESS", "resultStorage": { - "toolResultsHistory": - "projects/123/histories/456", - "resultsUri": - "https://example.com/results", - "gcsPath": - "gs://bucket/path/to/somewhere", - "toolResultsExecution": - "projects/123/histories/456/executions/789", + "toolResultsHistory": "projects/123/histories/456", + "resultsUri": "https://example.com/results", + "gcsPath": "gs://bucket/path/to/somewhere", + "toolResultsExecution": "projects/123/histories/456/executions/789", }, "clientInfo": { "client": "gcloud", }, "testMatrixId": "testmatrix-123", - }) + }, + ) _event_handler(func, raw_event) @@ -119,20 +118,17 @@ def init(): "invalidMatrixDetails": "Some details", "outcomeSummary": "SUCCESS", "resultStorage": { - "toolResultsHistory": - "projects/123/histories/456", - "resultsUri": - "https://example.com/results", - "gcsPath": - "gs://bucket/path/to/somewhere", - "toolResultsExecution": - "projects/123/histories/456/executions/789", + "toolResultsHistory": "projects/123/histories/456", + "resultsUri": "https://example.com/results", + "gcsPath": "gs://bucket/path/to/somewhere", + "toolResultsExecution": "projects/123/histories/456/executions/789", }, "clientInfo": { "client": "gcloud", }, "testMatrixId": "testmatrix-123", - }) + }, + ) decorated_func = on_test_matrix_completed()(func) decorated_func(raw_event) diff --git a/tests/test_util.py b/tests/test_util.py index cb13d309..34d975d2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -14,13 +14,24 @@ """ Internal utils tests. 
""" -from os import environ, path -from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion, _unsafe_decode_id_token + import datetime as _dt +from os import environ, path + +from firebase_functions.private.util import ( + PrecisionTimestamp, + _unsafe_decode_id_token, + deep_merge, + firebase_config, + get_precision_timestamp, + microsecond_timestamp_conversion, + nanoseconds_timestamp_conversion, + normalize_path, + second_timestamp_conversion, +) test_bucket = "python-functions-testing.appspot.com" -test_config_file = path.join(path.dirname(path.realpath(__file__)), - "firebase_config_test.json") +test_config_file = path.join(path.dirname(path.realpath(__file__)), "firebase_config_test.json") def test_firebase_config_loads_from_env_json(): @@ -30,7 +41,8 @@ def test_firebase_config_loads_from_env_json(): """ environ["FIREBASE_CONFIG"] = f'{{"storageBucket": "{test_bucket}"}}' assert firebase_config().storage_bucket == test_bucket, ( - "Failure, firebase_config did not load from env variable.") + "Failure, firebase_config did not load from env variable." + ) def test_firebase_config_loads_from_env_file(): @@ -40,7 +52,8 @@ def test_firebase_config_loads_from_env_file(): """ environ["FIREBASE_CONFIG"] = test_config_file assert firebase_config().storage_bucket == test_bucket, ( - "Failure, firebase_config did not load from env variable.") + "Failure, firebase_config did not load from env variable." + ) def test_microsecond_conversion(): @@ -55,11 +68,9 @@ def test_microsecond_conversion(): ] for input_timestamp, expected_output in timestamps: - expected_datetime = _dt.datetime.strptime(expected_output, - "%Y-%m-%dT%H:%M:%S.%fZ") + expected_datetime = _dt.datetime.strptime(expected_output, "%Y-%m-%dT%H:%M:%S.%fZ") expected_datetime = expected_datetime.replace(tzinfo=_dt.timezone.utc) - assert microsecond_timestamp_conversion( - input_timestamp) == expected_datetime + assert microsecond_timestamp_conversion(input_timestamp) == expected_datetime def test_nanosecond_conversion(): @@ -74,11 +85,9 @@ def test_nanosecond_conversion(): ] for input_timestamp, expected_output in timestamps: - expected_datetime = _dt.datetime.strptime(expected_output, - "%Y-%m-%dT%H:%M:%S.%fZ") + expected_datetime = _dt.datetime.strptime(expected_output, "%Y-%m-%dT%H:%M:%S.%fZ") expected_datetime = expected_datetime.replace(tzinfo=_dt.timezone.utc) - assert nanoseconds_timestamp_conversion( - input_timestamp) == expected_datetime + assert nanoseconds_timestamp_conversion(input_timestamp) == expected_datetime def test_second_conversion(): @@ -93,8 +102,7 @@ def test_second_conversion(): ] for input_timestamp, expected_output in timestamps: - expected_datetime = _dt.datetime.strptime(expected_output, - "%Y-%m-%dT%H:%M:%SZ") + expected_datetime = _dt.datetime.strptime(expected_output, "%Y-%m-%dT%H:%M:%SZ") expected_datetime = expected_datetime.replace(tzinfo=_dt.timezone.utc) assert second_timestamp_conversion(input_timestamp) == expected_datetime @@ -118,30 +126,18 @@ def test_is_nanoseconds_timestamp(): second_timestamp3 = "2023-03-21T06:43:58Z" second_timestamp4 = "2023-08-15T22:22:22Z" - assert get_precision_timestamp( - microsecond_timestamp1) is PrecisionTimestamp.MICROSECONDS - assert get_precision_timestamp( - microsecond_timestamp2) is PrecisionTimestamp.MICROSECONDS - assert get_precision_timestamp( - microsecond_timestamp3) is 
PrecisionTimestamp.MICROSECONDS - assert get_precision_timestamp( - microsecond_timestamp4) is PrecisionTimestamp.MICROSECONDS - assert get_precision_timestamp( - nanosecond_timestamp1) is PrecisionTimestamp.NANOSECONDS - assert get_precision_timestamp( - nanosecond_timestamp2) is PrecisionTimestamp.NANOSECONDS - assert get_precision_timestamp( - nanosecond_timestamp3) is PrecisionTimestamp.NANOSECONDS - assert get_precision_timestamp( - nanosecond_timestamp4) is PrecisionTimestamp.NANOSECONDS - assert get_precision_timestamp( - second_timestamp1) is PrecisionTimestamp.SECONDS - assert get_precision_timestamp( - second_timestamp2) is PrecisionTimestamp.SECONDS - assert get_precision_timestamp( - second_timestamp3) is PrecisionTimestamp.SECONDS - assert get_precision_timestamp( - second_timestamp4) is PrecisionTimestamp.SECONDS + assert get_precision_timestamp(microsecond_timestamp1) is PrecisionTimestamp.MICROSECONDS + assert get_precision_timestamp(microsecond_timestamp2) is PrecisionTimestamp.MICROSECONDS + assert get_precision_timestamp(microsecond_timestamp3) is PrecisionTimestamp.MICROSECONDS + assert get_precision_timestamp(microsecond_timestamp4) is PrecisionTimestamp.MICROSECONDS + assert get_precision_timestamp(nanosecond_timestamp1) is PrecisionTimestamp.NANOSECONDS + assert get_precision_timestamp(nanosecond_timestamp2) is PrecisionTimestamp.NANOSECONDS + assert get_precision_timestamp(nanosecond_timestamp3) is PrecisionTimestamp.NANOSECONDS + assert get_precision_timestamp(nanosecond_timestamp4) is PrecisionTimestamp.NANOSECONDS + assert get_precision_timestamp(second_timestamp1) is PrecisionTimestamp.SECONDS + assert get_precision_timestamp(second_timestamp2) is PrecisionTimestamp.SECONDS + assert get_precision_timestamp(second_timestamp3) is PrecisionTimestamp.SECONDS + assert get_precision_timestamp(second_timestamp4) is PrecisionTimestamp.SECONDS def test_normalize_document_path(): @@ -150,16 +146,15 @@ def test_normalize_document_path(): is normalized. """ test_path = "/test/document/" - assert normalize_path(test_path) == "test/document", ( - "Failure, path was not normalized.") + assert normalize_path(test_path) == "test/document", "Failure, path was not normalized." test_path1 = "//////test/document//////////" - assert normalize_path(test_path1) == "test/document", ( - "Failure, path was not normalized.") + assert normalize_path(test_path1) == "test/document", "Failure, path was not normalized." test_path2 = "test/document" assert normalize_path(test_path2) == "test/document", ( - "Failure, path should not be changed if it is already normalized.") + "Failure, path should not be changed if it is already normalized." + ) def test_toplevel_keys():