diff --git a/CHANGELOG.md b/CHANGELOG.md index 0444d3c..a3cb45d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- Allow to define a custom user management via the `after_logged_in` method #156 + +### Changed +- Updated the `public_callback` to work in more cases #163 + ## [2.3.0] - 2024-03-18 ### Added - OIDCAuth allows to authenticate via OIDC diff --git a/dash_auth/public_routes.py b/dash_auth/public_routes.py index 5c9540c..dfe6067 100644 --- a/dash_auth/public_routes.py +++ b/dash_auth/public_routes.py @@ -1,9 +1,10 @@ -import inspect +import logging import os -from dash import Dash, callback -from dash._callback import GLOBAL_CALLBACK_MAP -from dash import get_app +from dash import Dash, Output, callback, get_app +from dash._callback import handle_grouped_callback_args +from dash._grouping import flatten_grouping +from dash._utils import create_callback_id from werkzeug.routing import Map, MapAdapter, Rule @@ -70,12 +71,20 @@ def public_callback(*callback_args, **callback_kwargs): def decorator(func): wrapped_func = callback(*callback_args, **callback_kwargs)(func) - callback_id = next( - ( - k for k, v in GLOBAL_CALLBACK_MAP.items() - if inspect.getsource(v["callback"]) == inspect.getsource(func) - ), - None, + output, inputs, _, _, _ = handle_grouped_callback_args( + callback_args, callback_kwargs + ) + if isinstance(output, Output): + # Insert callback with scalar (non-multi) Output + output = output + has_output = True + else: + # Insert callback as multi Output + output = flatten_grouping(output) + has_output = len(output) > 0 + + callback_id = create_callback_id( + output, inputs, no_output=not has_output ) try: app = get_app() @@ -83,7 +92,7 @@ def decorator(func): get_public_callbacks(app) + [callback_id] ) except Exception: - print( + logging.info( "Could not set up the public callback as the Dash object " "has not yet been instantiated." ) diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py index 5442a67..d5b24bd 100644 --- a/tests/test_oidc_auth.py +++ b/tests/test_oidc_auth.py @@ -1,14 +1,10 @@ -import os from unittest.mock import patch import requests from dash import Dash, Input, Output, dcc, html from flask import redirect -from dash_auth import ( - protected_callback, - OIDCAuth, -) +from dash_auth import OIDCAuth, protected_callback def valid_authorize_redirect(_, redirect_uri, *args, **kwargs): @@ -17,7 +13,9 @@ def valid_authorize_redirect(_, redirect_uri, *args, **kwargs): def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs): base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1] - return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong") + return redirect( + f"{base_url}?error=Unauthorized&error_description=something went wrong" + ) def valid_authorize_access_token(*args, **kwargs): @@ -27,18 +25,26 @@ def valid_authorize_access_token(*args, **kwargs): } -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + valid_authorize_redirect, +) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", + valid_authorize_access_token, +) def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output1"), - html.Div(id="output2"), - html.Div("static", id="output3"), - html.Div("static", id="output4"), - html.Div("not static", id="output5"), - ]) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div("static", id="output3"), + html.Div("static", id="output4"), + html.Div("not static", id="output5"), + ] + ) @app.callback(Output("output1", "children"), Input("input", "value")) def update_output1(new_value): @@ -101,13 +107,15 @@ def update_output5(new_value): dash_br.wait_for_text_to_equal("#output5", "initial value") -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + invalid_authorize_redirect, +) def test_oa002_oidc_auth_login_fail(dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output") - ]) + app.layout = html.Div( + [dcc.Input(id="input", value="initial value"), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("input", "value")) def update_output(new_value): @@ -122,7 +130,7 @@ def update_output(new_value): server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration", ) dash_thread_server(app) - base_url = dash_thread_server.url + base_url = dash_thread_server.url.rstrip("/") def test_unauthorized(url): r = requests.get(url) @@ -133,17 +141,25 @@ def test_authorized(url): assert requests.get(url).status_code == 200 test_unauthorized(base_url) - test_authorized(os.path.join(base_url, "public")) + test_authorized(base_url + "/public") -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) -@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", + valid_authorize_redirect, +) +@patch( + "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", + valid_authorize_access_token, +) def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server): app = Dash(__name__) - app.layout = html.Div([ - dcc.Input(id="input", value="initial value"), - html.Div(id="output1"), - ]) + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + ] + ) @app.callback(Output("output1", "children"), Input("input", "value")) def update_output1(new_value): @@ -168,21 +184,21 @@ def update_output1(new_value): ) dash_thread_server(app) - base_url = dash_thread_server.url + base_url = dash_thread_server.url.rstrip("/") assert requests.get(base_url).status_code == 400 # Login with IDP1 - assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200 + assert requests.get(base_url + "/oidc/idp1/login").status_code == 200 # Logout - assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200 + assert requests.get(base_url + "/oidc/logout").status_code == 200 assert requests.get(base_url).status_code == 400 # Login with IDP2 - assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200 + assert requests.get(base_url + "/oidc/idp2/login").status_code == 200 - dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login")) + dash_br.driver.get(base_url + "/oidc/idp2/login") dash_br.driver.get(base_url) - dash_br.wait_for_text_to_equal("#output1", "initial value") + dash_br.wait_for_text_to_equal("#output1", "initial value") \ No newline at end of file