diff --git a/dash_auth/public_routes.py b/dash_auth/public_routes.py index 5c9540c..b41c520 100644 --- a/dash_auth/public_routes.py +++ b/dash_auth/public_routes.py @@ -1,12 +1,10 @@ import inspect import os -from dash import Dash, callback +from dash import Dash, callback, get_app from dash._callback import GLOBAL_CALLBACK_MAP -from dash import get_app from werkzeug.routing import Map, MapAdapter, Rule - DASH_PUBLIC_ASSETS_EXTENSIONS = "js,css" BASE_PUBLIC_ROUTES = [ f"/assets/.{ext}" @@ -68,20 +66,21 @@ def public_callback(*callback_args, **callback_kwargs): """ def decorator(func): - wrapped_func = callback(*callback_args, **callback_kwargs)(func) callback_id = next( ( - k for k, v in GLOBAL_CALLBACK_MAP.items() - if inspect.getsource(v["callback"]) == inspect.getsource(func) + k + for k, v in GLOBAL_CALLBACK_MAP.items() + if "callback" in v + and inspect.getsource(v["callback"]) == inspect.getsource(func) ), None, ) try: app = get_app() - app.server.config[PUBLIC_CALLBACKS] = ( - get_public_callbacks(app) + [callback_id] - ) + app.server.config[PUBLIC_CALLBACKS] = get_public_callbacks(app) + [ + callback_id + ] except Exception: print( "Could not set up the public callback as the Dash object " diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py index 5442a67..ed1a8da 100644 --- a/tests/test_oidc_auth.py +++ b/tests/test_oidc_auth.py @@ -1,14 +1,10 @@ -import os from unittest.mock import patch import requests from dash import Dash, Input, Output, dcc, html from flask import redirect -from dash_auth import ( - protected_callback, - OIDCAuth, -) +from dash_auth import OIDCAuth, protected_callback def valid_authorize_redirect(_, redirect_uri, *args, **kwargs): @@ -17,7 +13,9 @@ def valid_authorize_redirect(_, redirect_uri, *args, **kwargs): def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs): base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1] - return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong") + return redirect( + f"{base_url}?error=Unauthorized&error_description=something went wrong" + ) def valid_authorize_access_token(*args, **kwargs): @@ -27,18 +25,26 @@ def valid_authorize_access_token(*args, **kwargs): } -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + valid_authorize_redirect, +) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", + valid_authorize_access_token, +) def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output1"), - html.Div(id="output2"), - html.Div("static", id="output3"), - html.Div("static", id="output4"), - html.Div("not static", id="output5"), - ]) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div("static", id="output3"), + html.Div("static", id="output4"), + html.Div("not static", id="output5"), + ] + ) @app.callback(Output("output1", "children"), Input("input", "value")) def update_output1(new_value): @@ -101,13 +107,15 @@ def update_output5(new_value): dash_br.wait_for_text_to_equal("#output5", "initial value") -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + invalid_authorize_redirect, +) def test_oa002_oidc_auth_login_fail(dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output") - ]) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("input", "value")) def update_output(new_value): @@ -122,7 +130,7 @@ def update_output(new_value): server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration", ) dash_thread_server(app) - base_url = dash_thread_server.url + base_url = dash_thread_server.url.rstrip("/") def test_unauthorized(url): r = requests.get(url) @@ -133,17 +141,25 @@ def test_authorized(url): assert requests.get(url).status_code == 200 test_unauthorized(base_url) - test_authorized(os.path.join(base_url, "public")) + test_authorized(base_url + "/public") -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + valid_authorize_redirect, +) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", + valid_authorize_access_token, +) def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output1"), - ]) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + ] + ) @app.callback(Output("output1", "children"), Input("input", "value")) def update_output1(new_value): @@ -168,21 +184,21 @@ def update_output1(new_value): ) dash_thread_server(app) - base_url = dash_thread_server.url + base_url = dash_thread_server.url.rstrip("/") assert requests.get(base_url).status_code == 400 # Login with IDP1 - assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200 + assert requests.get(base_url + "/oidc/idp1/login").status_code == 200 # Logout - assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200 + assert requests.get(base_url + "/oidc/logout").status_code == 200 assert requests.get(base_url).status_code == 400 # Login with IDP2 - assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200 + assert requests.get(base_url + "/oidc/idp2/login").status_code == 200 - dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login")) + dash_br.driver.get(base_url + "/oidc/idp2/login") dash_br.driver.get(base_url) dash_br.wait_for_text_to_equal("#output1", "initial value")