|
6 | 6 |
|
7 | 7 | import json |
8 | 8 | import html |
9 | | -from http import HTTPStatus |
| 9 | +import logging |
10 | 10 | from typing import Optional, cast |
11 | 11 |
|
12 | 12 | from fastapi import FastAPI, Form, Request, Depends, Response |
|
29 | 29 | app.add_middleware(SessionMiddleware) |
30 | 30 |
|
31 | 31 |
|
| 32 | +logging.basicConfig( |
| 33 | + level=logging.DEBUG, |
| 34 | + format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s", |
| 35 | + datefmt="%H:%M:%S", |
| 36 | +) |
| 37 | + |
| 38 | + |
32 | 39 | async def get_auth_server() -> AuthServer: |
33 | 40 | """ |
34 | 41 | initialize oauth authorization server |
@@ -74,8 +81,24 @@ async def authorize( |
74 | 81 | oauth2 authorization endpoint using aioauth |
75 | 82 | """ |
76 | 83 | oauthreq = await to_request(request) |
| 84 | + user = request.session.get("user", None) |
| 85 | + |
77 | 86 | response = await oauth.create_authorization_response(oauthreq) |
78 | | - if response.status_code == HTTPStatus.UNAUTHORIZED: |
| 87 | + |
| 88 | + # A demonstration example of request validation before checking the user's credentials. |
| 89 | + # See a discussion here: https://github.com/aliev/aioauth/issues/101 |
| 90 | + if response.status_code >= 400: |
| 91 | + content = f""" |
| 92 | + <html> |
| 93 | + <body> |
| 94 | + <h3>{response.content['error']}</h3> |
| 95 | + <p style="color: red">{response.content['description']}</p> |
| 96 | + </body> |
| 97 | + </html> |
| 98 | + """ |
| 99 | + return HTMLResponse(content, status_code=response.status_code) |
| 100 | + |
| 101 | + if user is None: |
79 | 102 | request.session["oauth"] = oauthreq |
80 | 103 | return RedirectResponse("/login") |
81 | 104 | return to_response(response) |
@@ -155,18 +178,29 @@ async def approve(request: Request): |
155 | 178 | if "user" not in request.session: |
156 | 179 | redirect = request.url_for("login") |
157 | 180 | return RedirectResponse(redirect) |
158 | | - oauthreq: OAuthRequest = request.session["oauth"] |
159 | | - content = f""" |
160 | | -<html> |
161 | | - <body> |
162 | | - <h3>{oauthreq.query.client_id} would like permissions.</h3> |
163 | | - <form method="POST"> |
164 | | - <button name="approval" value="0" type="submit">Deny</button> |
165 | | - <button name="approval" value="1" type="submit">Approve</button> |
166 | | - </form> |
167 | | - </body> |
168 | | -</html> |
169 | | - """ |
| 181 | + |
| 182 | + oauth = request.session.get("oauth", None) |
| 183 | + if oauth: |
| 184 | + oauthreq: OAuthRequest = request.session["oauth"] |
| 185 | + content = f""" |
| 186 | + <html> |
| 187 | + <body> |
| 188 | + <h3>{oauthreq.query.client_id} would like permissions.</h3> |
| 189 | + <form method="POST"> |
| 190 | + <button name="approval" value="0" type="submit">Deny</button> |
| 191 | + <button name="approval" value="1" type="submit">Approve</button> |
| 192 | + </form> |
| 193 | + </body> |
| 194 | + </html> |
| 195 | + """ |
| 196 | + else: |
| 197 | + content = f""" |
| 198 | + <html> |
| 199 | + <body> |
| 200 | + <h3>Hello, {request.session['user'].username}.</h3> |
| 201 | + </body> |
| 202 | + </html> |
| 203 | + """ |
170 | 204 | return HTMLResponse(content) |
171 | 205 |
|
172 | 206 |
|
|
0 commit comments