diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0444d3c..a3cb45d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
 and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
+## [Unreleased]
+### Added
+- Allow to define a custom user management via the `after_logged_in` method #156
+
+### Changed
+- Updated the `public_callback` to work in more cases #163
+
 ## [2.3.0] - 2024-03-18
 ### Added
 - OIDCAuth allows to authenticate via OIDC
diff --git a/dash_auth/public_routes.py b/dash_auth/public_routes.py
index 5c9540c..dfe6067 100644
--- a/dash_auth/public_routes.py
+++ b/dash_auth/public_routes.py
@@ -1,9 +1,10 @@
-import inspect
+import logging
 import os
 
-from dash import Dash, callback
-from dash._callback import GLOBAL_CALLBACK_MAP
-from dash import get_app
+from dash import Dash, Output, callback, get_app
+from dash._callback import handle_grouped_callback_args
+from dash._grouping import flatten_grouping
+from dash._utils import create_callback_id
 from werkzeug.routing import Map, MapAdapter, Rule
 
 
@@ -70,12 +71,20 @@ def public_callback(*callback_args, **callback_kwargs):
     def decorator(func):
 
         wrapped_func = callback(*callback_args, **callback_kwargs)(func)
-        callback_id = next(
-            (
-                k for k, v in GLOBAL_CALLBACK_MAP.items()
-                if inspect.getsource(v["callback"]) == inspect.getsource(func)
-            ),
-            None,
+        output, inputs, _, _, _ = handle_grouped_callback_args(
+            callback_args, callback_kwargs
+        )
+        if isinstance(output, Output):
+            # Insert callback with scalar (non-multi) Output
+            output = output
+            has_output = True
+        else:
+            # Insert callback as multi Output
+            output = flatten_grouping(output)
+            has_output = len(output) > 0
+
+        callback_id = create_callback_id(
+            output, inputs, no_output=not has_output
         )
         try:
             app = get_app()
@@ -83,7 +92,7 @@ def decorator(func):
                 get_public_callbacks(app) + [callback_id]
             )
         except Exception:
-            print(
+            logging.info(
                 "Could not set up the public callback as the Dash object "
                 "has not yet been instantiated."
             )
diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py
index 5442a67..d5b24bd 100644
--- a/tests/test_oidc_auth.py
+++ b/tests/test_oidc_auth.py
@@ -1,14 +1,10 @@
-import os
 from unittest.mock import patch
 
 import requests
 from dash import Dash, Input, Output, dcc, html
 from flask import redirect
 
-from dash_auth import (
-    protected_callback,
-    OIDCAuth,
-)
+from dash_auth import OIDCAuth, protected_callback
 
 
 def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
@@ -17,7 +13,9 @@ def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
 
 def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs):
     base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1]
-    return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong")
+    return redirect(
+        f"{base_url}?error=Unauthorized&error_description=something went wrong"
+    )
 
 
 def valid_authorize_access_token(*args, **kwargs):
@@ -27,18 +25,26 @@ def valid_authorize_access_token(*args, **kwargs):
     }
 
 
-@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
-@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
+@patch(
+    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
+    valid_authorize_redirect,
+)
+@patch(
+    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
+    valid_authorize_access_token,
+)
 def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server):
     app = Dash(__name__)
-    app.layout = html.Div([
-        dcc.Input(id="input", value="initial value"),
-        html.Div(id="output1"),
-        html.Div(id="output2"),
-        html.Div("static", id="output3"),
-        html.Div("static", id="output4"),
-        html.Div("not static", id="output5"),
-    ])
+    app.layout = html.Div(
+        [
+            dcc.Input(id="input", value="initial value"),
+            html.Div(id="output1"),
+            html.Div(id="output2"),
+            html.Div("static", id="output3"),
+            html.Div("static", id="output4"),
+            html.Div("not static", id="output5"),
+        ]
+    )
 
     @app.callback(Output("output1", "children"), Input("input", "value"))
     def update_output1(new_value):
@@ -101,13 +107,15 @@ def update_output5(new_value):
     dash_br.wait_for_text_to_equal("#output5", "initial value")
 
 
-@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect)
+@patch(
+    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
+    invalid_authorize_redirect,
+)
 def test_oa002_oidc_auth_login_fail(dash_thread_server):
     app = Dash(__name__)
-    app.layout = html.Div([
-        dcc.Input(id="input", value="initial value"),
-        html.Div(id="output")
-    ])
+    app.layout = html.Div(
+        [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
+    )
 
     @app.callback(Output("output", "children"), Input("input", "value"))
     def update_output(new_value):
@@ -122,7 +130,7 @@ def update_output(new_value):
         server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
     )
     dash_thread_server(app)
-    base_url = dash_thread_server.url
+    base_url = dash_thread_server.url.rstrip("/")
 
     def test_unauthorized(url):
         r = requests.get(url)
@@ -133,17 +141,25 @@ def test_authorized(url):
         assert requests.get(url).status_code == 200
 
     test_unauthorized(base_url)
-    test_authorized(os.path.join(base_url, "public"))
+    test_authorized(base_url + "/public")
 
 
-@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
-@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
+@patch(
+    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
+    valid_authorize_redirect,
+)
+@patch(
+    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
+    valid_authorize_access_token,
+)
 def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server):
     app = Dash(__name__)
-    app.layout = html.Div([
-        dcc.Input(id="input", value="initial value"),
-        html.Div(id="output1"),
-    ])
+    app.layout = html.Div(
+        [
+            dcc.Input(id="input", value="initial value"),
+            html.Div(id="output1"),
+        ]
+    )
 
     @app.callback(Output("output1", "children"), Input("input", "value"))
     def update_output1(new_value):
@@ -168,21 +184,21 @@ def update_output1(new_value):
     )
 
     dash_thread_server(app)
-    base_url = dash_thread_server.url
+    base_url = dash_thread_server.url.rstrip("/")
 
     assert requests.get(base_url).status_code == 400
 
     # Login with IDP1
-    assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200
+    assert requests.get(base_url + "/oidc/idp1/login").status_code == 200
 
     # Logout
-    assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200
+    assert requests.get(base_url + "/oidc/logout").status_code == 200
 
     assert requests.get(base_url).status_code == 400
 
     # Login with IDP2
-    assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200
+    assert requests.get(base_url + "/oidc/idp2/login").status_code == 200
 
-    dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login"))
+    dash_br.driver.get(base_url + "/oidc/idp2/login")
     dash_br.driver.get(base_url)
-    dash_br.wait_for_text_to_equal("#output1", "initial value")
+    dash_br.wait_for_text_to_equal("#output1", "initial value")
\ No newline at end of file