-
Notifications
You must be signed in to change notification settings - Fork 2
/
external_oidc_into_oauth2.py
201 lines (184 loc) · 7.13 KB
/
external_oidc_into_oauth2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import base64
import html
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import httpx
from jose import jwk, jwt
from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import Route

# jwt_sketch is a local module in this directory with a toy JWT implementation.
from jwt_sketch import create_jwt, read_jwt
from http_basic_into_oauth2 import authenticate, data, refresh
BASE_URL = "http://localhost:8000"
WEB_APP_URL = "http://localhost:8001/"
SIMPLE_OIDC_BASE_URL = "http://localhost:9000"
AUTH_ENDPOINT = f"{SIMPLE_OIDC_BASE_URL}/auth"
TOKEN_ENDPOINT = f"{SIMPLE_OIDC_BASE_URL}/token"
CLIENT_ID = "example_client_id"
CLIENT_SECRET = "example_client_secret"

# When the simple-oidc-provider starts, it generates fresh random certs.
# Download them here. In a real application, this would be configured separately.
# This runs at import time, so the provider must already be reachable.
_certs_response = httpx.get(f"{SIMPLE_OIDC_BASE_URL}/certs")
# Surface an HTTP error from the provider directly, rather than as a confusing
# JSON-decoding or KeyError failure on the next line.
_certs_response.raise_for_status()
KEYS = _certs_response.json()["keys"]
# Tag every published key as RS256 so that jwk.construct() (used in
# exchange_code_for_username) builds RS256 verifiers from it.
for key in KEYS:
    key["alg"] = "RS256"

# Where the user's browser is sent to log in for the device flow. The
# redirect_uri here must be identical to the one later sent to the token
# endpoint (handle_device_code_form uses f"{BASE_URL}/device_code_callback").
authorization_uri = httpx.URL(
    AUTH_ENDPOINT,
    params={
        "client_id": CLIENT_ID,
        "response_type": "code",
        "scope": "openid",
        "redirect_uri": f"{BASE_URL}/device_code_callback",
    },
)
async def code(request):
    """Trade an OIDC authorization code for this server's own token pair.

    Expects a ``code`` query parameter holding an authorization code that was
    issued for the web app's redirect URI (WEB_APP_URL). Presumably the web app
    forwards it here after the provider redirects to it; confirm against the
    web app.

    Returns:
        JSON ``{"refresh_token": ..., "access_token": ...}`` for the verified
        username, or a 400 JSON error if ``code`` is missing.
    """
    authorization_code = request.query_params.get("code")
    if not authorization_code:
        return JSONResponse({"error": "missing code"}, status_code=400)
    # exchange_code_for_username performs blocking HTTP I/O (httpx.post), so run
    # it in a worker thread instead of stalling the event loop for every client.
    username = await run_in_threadpool(
        exchange_code_for_username, authorization_code, WEB_APP_URL
    )
    access_token = create_token(
        {"sub": username, "type": "access"},
        # Deliberately tiny for this demo (10 seconds), presumably so that token
        # refresh is easy to observe; 10 * 60 (10 minutes) is a more typical value.
        lifetime=10,
    )
    refresh_token = create_token(
        {"sub": username, "type": "refresh"},
        lifetime=14 * 24 * 60 * 60,  # 2 weeks
    )
    return JSONResponse(
        {"refresh_token": refresh_token, "access_token": access_token}
    )
def exchange_code_for_username(code, redirect_uri):
    """Redeem an OIDC authorization code and return the verified user's ``sub``.

    Args:
        code: Authorization code issued by the identity provider.
        redirect_uri: Must be identical to the redirect_uri used in the
            authorization request that produced ``code`` (OAuth 2.0 requirement).

    Returns:
        The ``sub`` claim of the provider's ID token, after signature,
        algorithm, audience and at_hash verification.

    Raises:
        httpx.HTTPStatusError: if the token endpoint returns an error status.
        LookupError: if the ID token's ``kid`` matches none of KEYS.
        jose.exceptions.JWTError: if the ID token fails verification.
    """
    auth_value = base64.b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode()).decode()
    # NOTE(review): client credentials are sent both in the form body and via
    # HTTP Basic. RFC 6749 section 2.3.1 says a client MUST NOT use more than one
    # authentication method per request; presumably simple-oidc-provider
    # tolerates it. Confirm before changing.
    response = httpx.post(
        url=TOKEN_ENDPOINT,
        data={
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
            "code": code,
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
        },
        headers={"Authorization": f"Basic {auth_value}"},
    )
    response.raise_for_status()
    response_body = response.json()
    id_token = response_body["id_token"]
    access_token = response_body["access_token"]

    # Verify that the response is from the trusted server: select the published
    # key whose "kid" matches the (not yet verified) token header.
    kid = jwt.get_unverified_header(id_token)["kid"]
    for candidate_key in KEYS:
        if candidate_key.get("kid") == kid:
            key = jwk.construct(candidate_key)
            break
    else:
        known_kids = [candidate.get("kid") for candidate in KEYS]
        raise LookupError(f"Could not find kid {kid} among {known_kids}")

    verified_body = jwt.decode(
        id_token,
        key,
        # Only accept the algorithm the keys were tagged with at startup.
        algorithms=["RS256"],
        # Lets python-jose check the ID token's at_hash claim against this token.
        access_token=access_token,
        audience=CLIENT_ID,
    )
    return verified_body["sub"]
@dataclass
class PendingSession:
    """One in-flight device-authorization request awaiting user approval."""

    # Short code the user types into the browser form.
    user_code: str
    # Secret code the device presents when polling /token.
    device_code: str
    # Moment after which the session is treated as expired (compared against
    # naive datetime.now() elsewhere in this module).
    deadline: datetime
    # Filled in once the user has logged in with the identity provider;
    # None means "not verified yet".
    username: Optional[str] = None


PENDING_SESSIONS = []  # placeholder for a proper database
def unauthorized(message="..."):
    """Return an HTTP 401 JSON response whose body is ``{"error": message}``."""
    error_body = {"error": message}
    return JSONResponse(content=error_body, status_code=401)
def create_token(payload, lifetime):
    """Encode *payload* plus an ``exp`` claim ``lifetime`` seconds from now.

    Args:
        payload: Claims to include (e.g. ``sub`` and ``type``). The caller's
            dict is not modified.
        lifetime: Validity period in seconds.

    Returns:
        Whatever ``create_jwt`` produces for the combined claims.
    """
    expires_at = datetime.now() + timedelta(seconds=lifetime)
    # Build a new dict instead of writing "exp" into the caller's payload, so a
    # payload can be reused without leaking a stale expiry into later tokens.
    claims = {**payload, "exp": int(expires_at.timestamp())}
    return create_jwt(claims)
async def authorize(request):
    """Open a new device-authorization session and describe it to the device.

    A PendingSession with a random user code and device code, valid for
    15 minutes, is appended to PENDING_SESSIONS. The JSON reply tells the
    device where the user should log in (authorization_uri), where to poll
    (verification_uri), how often to poll, and which codes to use.
    """
    session_lifetime = timedelta(minutes=15)
    pending_session = PendingSession(
        user_code=secrets.token_hex(4).upper(),  # 8 uppercase hex characters
        device_code=secrets.token_hex(32),
        deadline=datetime.now() + session_lifetime,
    )
    PENDING_SESSIONS.append(pending_session)
    print(f"Created {pending_session}")
    reply = {
        "authorization_uri": str(authorization_uri),
        "verification_uri": f"{BASE_URL}/token",
        "interval": 2,  # seconds
        "device_code": pending_session.device_code,
        "expires_in": int(session_lifetime.total_seconds()),  # seconds
        "user_code": pending_session.user_code,
    }
    return JSONResponse(reply)
async def device_code_callback(request):
    """Render the form where the user enters the code shown on their device.

    This is the redirect_uri used in ``authorization_uri``. It is reached with
    ``?code=...``, and that authorization code is carried in a hidden field so
    that handle_device_code_form can redeem it together with the user code.
    Returns a 400 page if ``code`` is missing.
    """
    authorization_code = request.query_params.get("code")
    if not authorization_code:
        return HTMLResponse("<html><body>Missing code</body></html>", status_code=400)
    # The code comes straight from the query string. Escape it (including
    # quotes) so a crafted link cannot break out of the value="..." attribute
    # and inject markup or script into this page (reflected XSS).
    escaped_code = html.escape(authorization_code, quote=True)
    return HTMLResponse(f"""
<html>
<body>
<form action="{BASE_URL}/device_code_form" method="post">
<label for="user_code">Enter code</label>
<input type="text" id="user_code" name="user_code" />
<input type="hidden" id="code" name="code" value="{escaped_code}" />
<input type="submit" value="Enter" />
</form>
</body>
</html>""")
async def handle_device_code_form(request):
    """Attach the logged-in user's identity to the session named by a user code.

    Handles the form posted from device_code_callback. The form has two fields:
    ``code`` (the OIDC authorization code, carried in a hidden field) and
    ``user_code`` (typed by the user).

    The authorization code is exchanged with the identity provider for a
    username. That username is stored on the matching, unexpired
    PendingSession, so the device's next /token poll receives tokens.

    Returns:
        200 on success. 401 if no live session has that user code. 400 if a
        form field is missing.
    """
    form_data = await request.form()
    authorization_code = form_data.get("code")
    # User codes are generated as uppercase hex; accept lowercase input and
    # stray surrounding whitespace from the person typing it.
    user_code = (form_data.get("user_code") or "").strip().upper()
    if not authorization_code or not user_code:
        return HTMLResponse("<html><body>Fail!</body></html>", status_code=400)

    # Must match the redirect_uri in authorization_uri, or the provider rejects it.
    redirect_uri = f"{BASE_URL}/device_code_callback"
    # Blocking HTTP call (httpx.post); keep it off the event loop.
    username = await run_in_threadpool(
        exchange_code_for_username, authorization_code, redirect_uri
    )

    # Update the pending session with the username from the identity provider.
    # Expired sessions (same "deadline < now" rule as /token) are not eligible.
    now = datetime.now()
    for pending_session in PENDING_SESSIONS:
        if pending_session.user_code == user_code and pending_session.deadline >= now:
            pending_session.username = username
            print(f"Verified {pending_session}")
            status_code = 200
            message = "And there was much rejoicing!"
            break
    else:
        status_code = 401
        message = "Fail!"
    return HTMLResponse(f"<html><body>{message}</body></html>", status_code=status_code)
async def token(request):
    """Polling endpoint for the device: trade a verified device code for tokens.

    Reads ``device_code`` from the posted form and purges expired pending
    sessions first. Then:

    * no live session has this device code: 401 "unrecognized device code -- maybe expired"
    * the session exists but the user has not logged in yet: 401 "pending"
    * the session is verified: it is consumed (removed), and JSON containing a
      10-second access token and a 2-week refresh token is returned.
    """
    form_data = await request.form()
    device_code = form_data.get("device_code")

    # Purge expired sessions by rebuilding the list rather than calling
    # remove() inside the loop. Removing from a list while iterating over it
    # skips the element that follows, which could hide a valid session.
    now = datetime.now()
    live_sessions = []
    for pending_session in PENDING_SESSIONS:
        if pending_session.deadline < now:
            print(f"Expired {pending_session}")
        else:
            live_sessions.append(pending_session)
    # Slice assignment updates the shared module-level list in place.
    PENDING_SESSIONS[:] = live_sessions

    # Is there a pending session for this device code? Has it been verified yet?
    for pending_session in PENDING_SESSIONS:
        if pending_session.device_code == device_code:
            break
    else:
        return unauthorized("unrecognized device code -- maybe expired")
    if pending_session.username is None:
        return unauthorized("pending")

    # Verified: consume the session so the device code can be redeemed only once.
    PENDING_SESSIONS.remove(pending_session)
    print(f"Used {pending_session}")

    access_token = create_token(
        {"sub": pending_session.username, "type": "access"},
        # Deliberately tiny for this demo (10 seconds); 10 * 60 (10 minutes)
        # is a more typical value.
        lifetime=10,
    )
    refresh_token = create_token(
        {"sub": pending_session.username, "type": "refresh"},
        lifetime=14 * 24 * 60 * 60,  # 2 weeks
    )
    return JSONResponse(
        {"refresh_token": refresh_token, "access_token": access_token}
    )
# Every endpoint this server exposes, as (path, handler, HTTP method).
routes = [
    Route(path, endpoint, methods=[method])
    for path, endpoint, method in (
        ("/data", data, "GET"),
        ("/authorize", authorize, "POST"),
        ("/code", code, "GET"),
        ("/device_code_callback", device_code_callback, "GET"),
        ("/device_code_form", handle_device_code_form, "POST"),
        ("/token", token, "POST"),
        ("/refresh", refresh, "POST"),
    )
]
# The ASGI application object serving the routes above.
app = Starlette(routes=routes)