-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathtest_oidc_auth.py
204 lines (168 loc) · 6.13 KB
/
test_oidc_auth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from unittest.mock import patch
import requests
from dash import Dash, Input, Output, dcc, html
from flask import redirect
from dash_auth import OIDCAuth, protected_callback
def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
    """Stand-in for ``FlaskOAuth2App.authorize_redirect`` that skips the IdP.

    Because it is patched in as a class attribute, it is invoked as a bound
    method; the first argument is the ``FlaskOAuth2App`` instance and is
    ignored. Instead of sending the client to an external identity provider,
    it redirects straight back to the callback URL it was given, so the login
    flow continues as though the IdP had approved the request.

    :param _: the patched ``FlaskOAuth2App`` instance (unused).
    :param redirect_uri: absolute callback URL, e.g.
        ``"http://host:port/some/path"``.
    :returns: a Flask redirect to the server-relative form of ``redirect_uri``.
    """
    # "http://host:port/a/b?x" splits into ["http:", "", "host:port", "a/b?x"];
    # the last piece is the path (plus any query) relative to the server root.
    relative_part = redirect_uri.split("/", 3)[-1]
    return redirect(f"/{relative_part}")
def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs):
    """Stand-in for ``FlaskOAuth2App.authorize_redirect`` that simulates a denial.

    Invoked as a bound method (patched in as a class attribute), so the first
    argument is the ``FlaskOAuth2App`` instance and is ignored. Rather than
    contacting an identity provider, it redirects straight back to the given
    callback with OAuth2 error parameters in the query string, mimicking an IdP
    that rejected the authorization request.

    :param _: the patched ``FlaskOAuth2App`` instance (unused).
    :param redirect_uri: absolute callback URL, e.g.
        ``"http://host:port/some/path"``.
    :returns: a Flask redirect to the server-relative callback carrying
        ``error=Unauthorized`` and ``error_description=something went wrong``.
    """
    from urllib.parse import quote, urlencode

    # "http://host:port/a/b" -> "/a/b": keep only the server-relative part.
    base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1]
    # Percent-encode the query explicitly (space -> %20) instead of emitting a
    # raw space in the URL and relying on redirect() to encode it for us.
    query = urlencode(
        {"error": "Unauthorized", "error_description": "something went wrong"},
        quote_via=quote,
    )
    return redirect(f"{base_url}?{query}")
def valid_authorize_access_token(*args, **kwargs):
    """Stand-in for ``FlaskOAuth2App.authorize_access_token``.

    Returns a canned token response without contacting any identity provider.
    All arguments (including the patched app instance) are ignored. The user
    belongs to the ``viewer`` and ``editor`` groups, which the group-protected
    callbacks in these tests are checked against.

    :returns: a new dict with ``userinfo`` (``email`` and ``groups``) and a
        ``refresh_token``.
    """
    return {
        # Placeholder address on a reserved example domain (RFC 2606).
        "userinfo": {"email": "user@example.com", "groups": ["viewer", "editor"]},
        "refresh_token": "ABCDEF",
    }
@patch(
    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
    valid_authorize_redirect,
)
@patch(
    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
    valid_authorize_access_token,
)
def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server):
    """A successful OIDC login lets group-gated callbacks run per their rules.

    Both authlib entry points are patched, so no real IdP is contacted; the
    patched token response gives the user the groups ``["viewer", "editor"]``.
    Expected outcome per output:

    * output1 -- plain callback, shows the input value;
    * output2 -- ``one_of ["editor"]``, satisfied, shows the input value;
    * output3 -- ``one_of ["admin"]``, not satisfied, stays "static";
    * output4 -- ``none_of ["viewer"]``, violated, stays "static";
    * output5 -- ``all_of ["viewer", "editor"]``, satisfied, shows the input
      value.
    """
    app = Dash(__name__)
    app.layout = html.Div(
        [
            dcc.Input(id="input", value="initial value"),
            html.Div(id="output1"),
            html.Div(id="output2"),
            html.Div("static", id="output3"),
            html.Div("static", id="output4"),
            html.Div("not static", id="output5"),
        ]
    )

    @app.callback(Output("output1", "children"), Input("input", "value"))
    def update_output1(new_value):
        return new_value

    @protected_callback(
        Output("output2", "children"),
        Input("input", "value"),
        groups=["editor"],
        check_type="one_of",
    )
    def update_output2(new_value):
        return new_value

    @protected_callback(
        Output("output3", "children"),
        Input("input", "value"),
        groups=["admin"],
        check_type="one_of",
    )
    def update_output3(new_value):
        return new_value

    @protected_callback(
        Output("output4", "children"),
        Input("input", "value"),
        groups=["viewer"],
        check_type="none_of",
    )
    def update_output4(new_value):
        return new_value

    @protected_callback(
        Output("output5", "children"),
        Input("input", "value"),
        groups=["viewer", "editor"],
        check_type="all_of",
    )
    def update_output5(new_value):
        return new_value

    oidc = OIDCAuth(app, secret_key="Test")
    oidc.register_provider(
        "oidc",
        token_endpoint_auth_method="client_secret_post",
        client_id="<client-id>",
        client_secret="<client-secret>",
        server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
    )

    dash_thread_server(app)
    base_url = dash_thread_server.url
    # Bound the request so a stuck redirect chain fails the test instead of
    # hanging the whole suite.
    assert requests.get(base_url, timeout=10).status_code == 200

    dash_br.driver.get(base_url)
    dash_br.wait_for_text_to_equal("#output1", "initial value")
    dash_br.wait_for_text_to_equal("#output2", "initial value")
    dash_br.wait_for_text_to_equal("#output3", "static")
    dash_br.wait_for_text_to_equal("#output4", "static")
    dash_br.wait_for_text_to_equal("#output5", "initial value")
@patch(
    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
    invalid_authorize_redirect,
)
def test_oa002_oidc_auth_login_fail(dash_thread_server):
    """An IdP error response yields 401, while public routes stay reachable.

    The patched ``authorize_redirect`` bounces straight back to the callback
    with ``error=Unauthorized`` and ``error_description=something went wrong``.
    The app root is expected to answer 401 with that description in the body,
    and the route listed in ``public_routes`` must still answer 200.
    """
    app = Dash(__name__)
    app.layout = html.Div(
        [dcc.Input(id="input", value="initial value"), html.Div(id="output")]
    )

    @app.callback(Output("output", "children"), Input("input", "value"))
    def update_output(new_value):
        return new_value

    oidc = OIDCAuth(app, public_routes=["/public"], secret_key="Test")
    oidc.register_provider(
        "oidc",
        token_endpoint_auth_method="client_secret_post",
        client_id="<client-id>",
        client_secret="<client-secret>",
        server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
    )

    dash_thread_server(app)
    base_url = dash_thread_server.url.rstrip("/")

    # Local assertion helpers; named assert_* rather than test_* so they are
    # not mistaken for pytest test functions. Every request is time-bounded so
    # a stuck redirect chain fails instead of hanging the suite.
    def assert_unauthorized(url):
        r = requests.get(url, timeout=10)
        assert r.status_code == 401
        assert r.text == "Unauthorized: something went wrong"

    def assert_authorized(url):
        assert requests.get(url, timeout=10).status_code == 200

    assert_unauthorized(base_url)
    assert_authorized(base_url + "/public")
@patch(
    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
    valid_authorize_redirect,
)
@patch(
    "authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
    valid_authorize_access_token,
)
def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server):
    """With two IdPs registered, login goes through an explicit per-IdP route.

    Hitting the app root without a session is expected to answer 400
    (presumably because no default IdP can be picked when several are
    registered). Logging in via ``/oidc/<idp>/login`` must make the app
    reachable, and ``/oidc/logout`` must revoke that access again.
    """
    app = Dash(__name__)
    app.layout = html.Div(
        [
            dcc.Input(id="input", value="initial value"),
            html.Div(id="output1"),
        ]
    )

    @app.callback(Output("output1", "children"), Input("input", "value"))
    def update_output1(new_value):
        return new_value

    oidc = OIDCAuth(app, secret_key="Test")
    # Add a first provider
    oidc.register_provider(
        "idp1",
        token_endpoint_auth_method="client_secret_post",
        client_id="<client-id>",
        client_secret="<client-secret>",
        server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
    )
    # Add a second provider
    oidc.register_provider(
        "idp2",
        token_endpoint_auth_method="client_secret_post",
        client_id="<client-id2>",
        client_secret="<client-secret2>",
        server_metadata_url="https://idp2.com/oidc/2/.well-known/openid-configuration",
    )

    dash_thread_server(app)
    base_url = dash_thread_server.url.rstrip("/")
    timeout = 10  # seconds; fail instead of hanging on a stuck redirect chain

    # A shared Session keeps the login cookie between requests. With bare
    # requests.get() every call started logged out, so the logout check could
    # never fail.
    with requests.Session() as http:
        assert http.get(base_url, timeout=timeout).status_code == 400
        # Login with IDP1; the app itself is then reachable.
        login1 = http.get(base_url + "/oidc/idp1/login", timeout=timeout)
        assert login1.status_code == 200
        assert http.get(base_url, timeout=timeout).status_code == 200
        # Logout drops the session, so the app is refused again.
        logout = http.get(base_url + "/oidc/logout", timeout=timeout)
        assert logout.status_code == 200
        assert http.get(base_url, timeout=timeout).status_code == 400
        # Login with IDP2
        login2 = http.get(base_url + "/oidc/idp2/login", timeout=timeout)
        assert login2.status_code == 200

    dash_br.driver.get(base_url + "/oidc/idp2/login")
    dash_br.driver.get(base_url)
    dash_br.wait_for_text_to_equal("#output1", "initial value")