Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: split tasks https #201

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/firebase_functions/https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,8 @@ class CallableRequest(_typing.Generic[_core.T]):
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]


def _on_call_handler(func: _C2,
request: Request,
enforce_app_check: bool,
verify_token: bool = True) -> Response:
def _on_call_handler(func: _C2, request: Request,
enforce_app_check: bool) -> Response:
try:
if not _util.valid_on_call_request(request):
_logging.error("Invalid request, unable to process.")
Expand All @@ -365,8 +363,7 @@ def _on_call_handler(func: _C2,
data=_json.loads(request.data)["data"],
)

token_status = _util.on_call_check_tokens(request,
verify_token=verify_token)
token_status = _util.on_call_check_tokens(request)

if token_status.auth == _util.OnCallTokenState.INVALID:
raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED,
Expand Down Expand Up @@ -420,7 +417,7 @@ def _on_call_handler(func: _C2,
def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]:
"""
Handler which handles HTTPS requests.
Requires a function that takes a ``Request`` and ``Response`` object,
Requires a function that takes a ``Request`` and ``Response`` object,
the same signature as a Flask app.

Example:
Expand Down
30 changes: 12 additions & 18 deletions src/firebase_functions/private/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,10 @@ def as_dict(self) -> dict:


def _on_call_check_auth_token(
request: _Request,
verify_token: bool = True,
request: _Request
) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]:
"""
Validates the auth token in a callable request.
Validates the auth token in a callable request.
If verify_token is False, the token will be decoded without verification.
"""
authorization = request.headers.get("Authorization")
Expand All @@ -227,10 +226,7 @@ def _on_call_check_auth_token(
return OnCallTokenState.INVALID
try:
id_token = authorization.replace("Bearer ", "")
if verify_token:
auth_token = _auth.verify_id_token(id_token)
else:
auth_token = _unsafe_decode_id_token(id_token)
auth_token = _auth.verify_id_token(id_token)
return auth_token
# pylint: disable=broad-except
except Exception as err:
Expand Down Expand Up @@ -273,25 +269,23 @@ def _unsafe_decode_id_token(token: str):
return payload


def on_call_check_tokens(request: _Request,
verify_token: bool = True) -> _OnCallTokenVerification:
def on_call_check_tokens(request: _Request) -> _OnCallTokenVerification:
"""Check tokens"""
verifications = _OnCallTokenVerification()

auth_token = _on_call_check_auth_token(request, verify_token=verify_token)
auth_token = _on_call_check_auth_token(request)
if auth_token is None:
verifications.auth = OnCallTokenState.MISSING
elif isinstance(auth_token, dict):
verifications.auth = OnCallTokenState.VALID
verifications.auth_token = auth_token

if verify_token:
app_token = _on_call_check_app_token(request)
if app_token is None:
verifications.app = OnCallTokenState.MISSING
elif isinstance(app_token, dict):
verifications.app = OnCallTokenState.VALID
verifications.app_token = app_token
app_token = _on_call_check_app_token(request)
if app_token is None:
verifications.app = OnCallTokenState.MISSING
elif isinstance(app_token, dict):
verifications.app = OnCallTokenState.VALID
verifications.app_token = app_token

log_payload = {
**verifications.as_dict(),
Expand All @@ -301,7 +295,7 @@ def on_call_check_tokens(request: _Request,
}

errs = []
if verify_token and verifications.app == OnCallTokenState.INVALID:
if verifications.app == OnCallTokenState.INVALID:
errs.append(("AppCheck token was rejected.", log_payload))

if verifications.auth == OnCallTokenState.INVALID:
Expand Down
50 changes: 44 additions & 6 deletions src/firebase_functions/tasks_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,55 @@
# pylint: disable=protected-access
import typing as _typing
import functools as _functools
import dataclasses as _dataclasses
import json as _json

from flask import Request, Response
from flask import Request, Response, make_response as _make_response, jsonify as _jsonify

import firebase_functions.core as _core
import firebase_functions.options as _options
import firebase_functions.private.util as _util
from firebase_functions.https_fn import CallableRequest, _on_call_handler
from firebase_functions.https_fn import CallableRequest, HttpsError, FunctionsErrorCode

from functions_framework import logging as _logging

_C = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
_C1 = _typing.Callable[[Request], Response]
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]


def _on_call_handler(func: _C2, request: Request) -> Response:
try:
if not _util.valid_on_call_request(request):
_logging.error("Invalid request, unable to process.")
raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request")
context: CallableRequest = CallableRequest(
raw_request=request,
data=_json.loads(request.data)["data"],
)

instance_id = request.headers.get("Firebase-Instance-ID-Token")
if instance_id is not None:
# Validating the token requires an http request, so we don't do it.
# If the user wants to use it for something, it will be validated then.
# Currently, the only real use case for this token is for sending
# pushes with FCM. In that case, the FCM APIs will validate the token.
context = _dataclasses.replace(
context,
instance_id_token=request.headers.get(
"Firebase-Instance-ID-Token"),
)
result = _core._with_init(func)(context)
return _jsonify(result=result)
# Disable broad exceptions lint since we want to handle all exceptions here
# and wrap as an HttpsError.
# pylint: disable=broad-except
except Exception as err:
if not isinstance(err, HttpsError):
_logging.error("Unhandled error: %s", err)
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
status = err._http_error_code.status
return _make_response(_jsonify(error=err._as_dict()), status)


@_util.copy_func_kwargs(_options.TaskQueueOptions)
Expand Down Expand Up @@ -53,10 +94,7 @@ def on_task_dispatched_decorator(func: _C):

@_functools.wraps(func)
def on_task_dispatched_wrapped(request: Request) -> Response:
return _on_call_handler(func,
request,
enforce_app_check=False,
verify_token=False)
return _on_call_handler(func, request)

_util.set_func_endpoint_attr(
on_task_dispatched_wrapped,
Expand Down
35 changes: 0 additions & 35 deletions tests/test_tasks_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,41 +71,6 @@ def example(request: CallableRequest[object]) -> str:
'{"result":"Hello World"}\n',
)

def test_token_is_decoded(self):
"""
Test that the token is decoded instead of verifying auth first.
"""
app = Flask(__name__)

@on_task_dispatched()
def example(request: CallableRequest[object]) -> str:
auth = request.auth
# Make mypy happy
if auth is None:
self.fail("Auth is None")
return "No Auth"
self.assertEqual(auth.token["sub"], "firebase")
self.assertEqual(auth.token["name"], "John Doe")
return "Hello World"

with app.test_request_context("/"):
# pylint: disable=line-too-long
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
environ = EnvironBuilder(
method="POST",
headers={
"Authorization": f"Bearer {test_token}"
},
json={
"data": {
"test": "value"
},
},
).get_environ()
request = Request(environ)
response = example(request)
self.assertEqual(response.status_code, 200)

def test_calls_init(self):
hello = None

Expand Down
Loading