Skip to content

Commit 7352502

Browse files
committed
fix: explicitly show the response error
1 parent e380fe2 commit 7352502

5 files changed

Lines changed: 73 additions & 47 deletions

File tree

examples/fastapi_example.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import json
88
import html
9-
from http import HTTPStatus
9+
import logging
1010
from typing import Optional, cast
1111

1212
from fastapi import FastAPI, Form, Request, Depends, Response
@@ -29,6 +29,13 @@
2929
app.add_middleware(SessionMiddleware)
3030

3131

32+
logging.basicConfig(
33+
level=logging.DEBUG,
34+
format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
35+
datefmt="%H:%M:%S",
36+
)
37+
38+
3239
async def get_auth_server() -> AuthServer:
3340
"""
3441
initialize oauth authorization server
@@ -74,8 +81,24 @@ async def authorize(
7481
oauth2 authorization endpoint using aioauth
7582
"""
7683
oauthreq = await to_request(request)
84+
user = request.session.get("user", None)
85+
7786
response = await oauth.create_authorization_response(oauthreq)
78-
if response.status_code == HTTPStatus.UNAUTHORIZED:
87+
88+
# A demonstration example of request validation before checking the user's credentials.
89+
# See a discussion here: https://github.com/aliev/aioauth/issues/101
90+
if response.status_code >= 400:
91+
content = f"""
92+
<html>
93+
<body>
94+
<h3>{response.content['error']}</h3>
95+
<p style="color: red">{response.content['description']}</p>
96+
</body>
97+
</html>
98+
"""
99+
return HTMLResponse(content, status_code=response.status_code)
100+
101+
if user is None:
79102
request.session["oauth"] = oauthreq
80103
return RedirectResponse("/login")
81104
return to_response(response)
@@ -155,18 +178,29 @@ async def approve(request: Request):
155178
if "user" not in request.session:
156179
redirect = request.url_for("login")
157180
return RedirectResponse(redirect)
158-
oauthreq: OAuthRequest = request.session["oauth"]
159-
content = f"""
160-
<html>
161-
<body>
162-
<h3>{oauthreq.query.client_id} would like permissions.</h3>
163-
<form method="POST">
164-
<button name="approval" value="0" type="submit">Deny</button>
165-
<button name="approval" value="1" type="submit">Approve</button>
166-
</form>
167-
</body>
168-
</html>
169-
"""
181+
182+
oauth = request.session.get("oauth", None)
183+
if oauth:
184+
oauthreq: OAuthRequest = request.session["oauth"]
185+
content = f"""
186+
<html>
187+
<body>
188+
<h3>{oauthreq.query.client_id} would like permissions.</h3>
189+
<form method="POST">
190+
<button name="approval" value="0" type="submit">Deny</button>
191+
<button name="approval" value="1" type="submit">Approve</button>
192+
</form>
193+
</body>
194+
</html>
195+
"""
196+
else:
197+
content = f"""
198+
<html>
199+
<body>
200+
<h3>Hello, {request.session['user'].username}.</h3>
201+
</body>
202+
</html>
203+
"""
170204
return HTMLResponse(content)
171205

172206

examples/shared/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"AuthServer",
2121
"BackendStore",
2222
"engine",
23-
"config",
23+
"app_config",
2424
"settings",
2525
"try_login",
2626
"lifespan",
@@ -32,8 +32,8 @@
3232
"sqlite+aiosqlite:///:memory:", echo=False, future=True
3333
)
3434

35-
config = load_config(CONFIG_PATH)
36-
settings = config.settings
35+
app_config = load_config(CONFIG_PATH)
36+
settings = app_config.settings
3737

3838

3939
async def try_login(username: str, password: str) -> Optional[User]:
@@ -59,9 +59,9 @@ async def lifespan(*_):
5959
await conn.run_sync(SQLModel.metadata.create_all)
6060
# create test records
6161
async with AsyncSession(engine) as session:
62-
for user in config.fixtures.users:
62+
for user in app_config.fixtures.users:
6363
session.add(user)
64-
for client in config.fixtures.clients:
64+
for client in app_config.fixtures.clients:
6565
session.add(client)
6666
await session.commit()
6767
yield

examples/shared/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,6 @@
1010
from .models import User, Client
1111

1212

13-
def load_config(fpath: str) -> "Config":
14-
"""load configuration from filepath"""
15-
with open(fpath, "r") as f:
16-
json = f.read()
17-
return Config.model_validate_json(json)
18-
19-
2013
class Fixtures(BaseModel):
2114
users: List[User]
2215
clients: List[Client]
@@ -25,3 +18,10 @@ class Fixtures(BaseModel):
2518
class Config(BaseModel):
2619
fixtures: Fixtures
2720
settings: Settings
21+
22+
23+
def load_config(fpath: str) -> Config:
24+
"""load configuration from filepath"""
25+
with open(fpath, "r") as f:
26+
json = f.read()
27+
return Config.model_validate_json(json)

examples/shared/storage.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ async def get_authorization_code(
113113
) -> Optional[AuthorizationCode]:
114114
""" """
115115
async with self.session:
116-
sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
116+
sql = (
117+
select(AuthCodeTable)
118+
.where(AuthCodeTable.client_id == client_id)
119+
.where(AuthCodeTable.code == code)
120+
)
117121
result = (await self.session.exec(sql)).one_or_none()
118122
if result is not None:
119123
return AuthorizationCode(
@@ -138,7 +142,11 @@ async def delete_authorization_code(
138142
) -> None:
139143
""" """
140144
async with self.session:
141-
sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
145+
sql = (
146+
select(AuthCodeTable)
147+
.where(AuthCodeTable.client_id == client_id)
148+
.where(AuthCodeTable.code == code)
149+
)
142150
result = (await self.session.exec(sql)).one()
143151
await self.session.delete(result)
144152
await self.session.commit()

tests/oidc/core/test_flow.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from http import HTTPStatus
2-
from typing import Optional
32

43
import pytest
54

@@ -8,27 +7,12 @@
87
generate_token,
98
)
109

11-
from tests.classes import User
1210
from tests.utils import check_request_validators
1311

1412

1513
@pytest.mark.asyncio
16-
@pytest.mark.parametrize(
17-
"user, expected_status_code",
18-
[
19-
("username", HTTPStatus.FOUND),
20-
(None, HTTPStatus.FOUND),
21-
],
22-
)
23-
async def test_authorization_endpoint_allows_prompt_query_param(
24-
expected_status_code: HTTPStatus,
25-
user: Optional[User],
26-
context_factory,
27-
):
28-
if user is None:
29-
context = context_factory()
30-
else:
31-
context = context_factory(users={user: "password"})
14+
async def test_authorization_endpoint_allows_prompt_query_param(context_factory):
15+
context = context_factory()
3216
server = context.server
3317
client = context.clients[0]
3418
client_id = client.client_id
@@ -52,4 +36,4 @@ async def test_authorization_endpoint_allows_prompt_query_param(
5236
await check_request_validators(request, server.create_authorization_response)
5337

5438
response = await server.create_authorization_response(request)
55-
assert response.status_code == expected_status_code
39+
assert response.status_code == HTTPStatus.FOUND

0 commit comments

Comments
 (0)