1
1
use std:: sync:: Arc ;
2
2
3
- use axum:: extract:: State ;
3
+ use axum:: extract:: { FromRef , State } ;
4
4
use axum:: http:: { self , HeaderMap , HeaderValue , Request } ;
5
5
use axum:: middleware:: Next ;
6
6
use axum:: response:: IntoResponse ;
7
7
use lib:: prelude:: * ;
8
8
9
- use super :: auth:: { AuthError , SecretApiKey } ;
9
+ use super :: auth:: { AuthError , Authenticator , SecretApiKey } ;
10
10
use super :: errors:: ApiError ;
11
11
use super :: AppState ;
12
12
13
13
const ON_BEHALF_OF_HEADER_NAME : & str = "X-On-Behalf-Of" ;
14
14
15
+ // Partial state from the main app state to facilitate writing tests for the
16
+ // middleware.
17
+ #[ derive( Clone ) ]
18
+ pub struct AuthenticationState {
19
+ authenticator : Authenticator ,
20
+ config : super :: config:: ApiSvcConfig ,
21
+ }
22
+
23
+ impl FromRef < Arc < AppState > > for AuthenticationState {
24
+ fn from_ref ( input : & Arc < AppState > ) -> Self {
25
+ Self {
26
+ authenticator : input. authenticator . clone ( ) ,
27
+ config : input. context . service_config ( ) ,
28
+ }
29
+ }
30
+ }
31
+
15
32
enum AuthenticationStatus {
16
33
Unauthenticated ,
17
34
Authenticated ( ValidShardedId < ProjectId > ) ,
@@ -61,14 +78,14 @@ fn get_auth_key(
61
78
}
62
79
63
80
async fn get_auth_status < B > (
64
- state : & AppState ,
81
+ state : & AuthenticationState ,
65
82
req : & Request < B > ,
66
83
) -> Result < AuthenticationStatus , ApiError > {
67
84
let auth_key = get_auth_key ( req. headers ( ) ) ?;
68
85
let Some ( auth_key) = auth_key else {
69
86
return Ok ( AuthenticationStatus :: Unauthenticated ) ;
70
87
} ;
71
- let config = state. context . service_config ( ) ;
88
+ let config = & state. config ;
72
89
let admin_keys = & config. admin_api_keys ;
73
90
if admin_keys. contains ( & auth_key) {
74
91
let project: Option < ValidShardedId < ProjectId > > = req
@@ -98,7 +115,10 @@ async fn get_auth_status<B>(
98
115
return Ok ( AuthenticationStatus :: Unauthenticated ) ;
99
116
} ;
100
117
101
- let project = state. authenicator . authenticate ( & user_provided_secret) . await ;
118
+ let project = state
119
+ . authenticator
120
+ . authenticate ( & user_provided_secret)
121
+ . await ;
102
122
match project {
103
123
| Ok ( project_id) => Ok ( AuthenticationStatus :: Authenticated ( project_id) ) ,
104
124
| Err ( AuthError :: AuthFailed ( _) ) => {
@@ -178,11 +198,11 @@ pub async fn ensure_admin<B>(
178
198
/// of the other "ensure_*" middlewares in this module to enforce the expected
179
199
/// AuthenticationStatus for a certain route.
180
200
pub async fn authenticate < B > (
181
- State ( state) : State < Arc < AppState > > ,
201
+ State ( state) : State < AuthenticationState > ,
182
202
mut req : Request < B > ,
183
203
next : Next < B > ,
184
204
) -> Result < impl IntoResponse , ApiError > {
185
- let auth_status = get_auth_status ( state. as_ref ( ) , & req) . await ?;
205
+ let auth_status = get_auth_status ( & state, & req) . await ?;
186
206
187
207
let project_id = auth_status. project_id ( ) ;
188
208
req. extensions_mut ( ) . insert ( auth_status) ;
@@ -200,3 +220,260 @@ pub async fn authenticate<B>(
200
220
201
221
Ok ( resp)
202
222
}
223
+
224
+ #[ cfg( test) ]
225
+ mod tests {
226
+
227
+ use std:: collections:: HashSet ;
228
+ use std:: fmt:: Debug ;
229
+
230
+ use axum:: routing:: get;
231
+ use axum:: { middleware, Router } ;
232
+ use cronback_api_model:: admin:: CreateAPIkeyRequest ;
233
+ use hyper:: { Body , StatusCode } ;
234
+ use tower:: ServiceExt ;
235
+
236
+ use super :: * ;
237
+ use crate :: api:: auth_store:: AuthStore ;
238
+ use crate :: api:: config:: ApiSvcConfig ;
239
+ use crate :: api:: ApiService ;
240
+
241
+ async fn make_state ( ) -> AuthenticationState {
242
+ let mut set = HashSet :: new ( ) ;
243
+ set. insert ( "adminkey1" . to_string ( ) ) ;
244
+ set. insert ( "adminkey2" . to_string ( ) ) ;
245
+
246
+ let config = ApiSvcConfig {
247
+ address : String :: new ( ) ,
248
+ port : 123 ,
249
+ database_uri : String :: new ( ) ,
250
+ admin_api_keys : set,
251
+ log_request_body : false ,
252
+ log_response_body : false ,
253
+ } ;
254
+
255
+ let db = ApiService :: in_memory_database ( ) . await . unwrap ( ) ;
256
+ let auth_store = AuthStore :: new ( db) ;
257
+ let authenticator = Authenticator :: new ( auth_store) ;
258
+
259
+ AuthenticationState {
260
+ authenticator,
261
+ config,
262
+ }
263
+ }
264
+
265
+ struct TestInput {
266
+ app : Router ,
267
+ auth_header : Option < String > ,
268
+ on_behalf_on_header : Option < String > ,
269
+ expected_status : StatusCode ,
270
+ }
271
+
272
+ impl Debug for TestInput {
273
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
274
+ f. debug_struct ( "TestInput" )
275
+ . field ( "auth_header" , & self . auth_header )
276
+ . field ( "on_behalf_on_header" , & self . on_behalf_on_header )
277
+ . field ( "expected_status" , & self . expected_status )
278
+ . finish ( )
279
+ }
280
+ }
281
+
282
+ struct TestExpectations {
283
+ unauthenticated : StatusCode ,
284
+ authenticated : StatusCode ,
285
+ admin_no_project : StatusCode ,
286
+ admin_with_project : StatusCode ,
287
+ unknown_secret_key : StatusCode ,
288
+ }
289
+
290
+ async fn run_tests (
291
+ app : Router ,
292
+ state : AuthenticationState ,
293
+ expectations : TestExpectations ,
294
+ ) -> anyhow:: Result < ( ) > {
295
+ // Define one project and generate a key for it.
296
+ let prj1 = ProjectId :: generate ( ) ;
297
+ let key = state
298
+ . authenticator
299
+ . gen_key (
300
+ CreateAPIkeyRequest {
301
+ key_name : "test" . to_string ( ) ,
302
+ metadata : Default :: default ( ) ,
303
+ } ,
304
+ & prj1,
305
+ )
306
+ . await ?;
307
+
308
+ let inputs = vec ! [
309
+ // Unauthenticated user
310
+ TestInput {
311
+ app: app. clone( ) ,
312
+ auth_header: None ,
313
+ on_behalf_on_header: None ,
314
+ expected_status: expectations. unauthenticated,
315
+ } ,
316
+ // Authenticated user
317
+ TestInput {
318
+ app: app. clone( ) ,
319
+ auth_header: Some ( format!( "Bearer {}" , key. unsafe_to_string( ) ) ) ,
320
+ on_behalf_on_header: None ,
321
+ expected_status: expectations. authenticated,
322
+ } ,
323
+ // Admin without project
324
+ TestInput {
325
+ app: app. clone( ) ,
326
+ auth_header: Some ( "Bearer adminkey1" . to_string( ) ) ,
327
+ on_behalf_on_header: None ,
328
+ expected_status: expectations. admin_no_project,
329
+ } ,
330
+ // Admin with project
331
+ TestInput {
332
+ app: app. clone( ) ,
333
+ auth_header: Some ( "Bearer adminkey1" . to_string( ) ) ,
334
+ on_behalf_on_header: Some ( prj1. to_string( ) ) ,
335
+ expected_status: expectations. admin_with_project,
336
+ } ,
337
+ // Unknown secret key
338
+ TestInput {
339
+ app: app. clone( ) ,
340
+ auth_header: Some ( format!(
341
+ "Bearer {}" ,
342
+ SecretApiKey :: generate( ) . unsafe_to_string( )
343
+ ) ) ,
344
+ on_behalf_on_header: Some ( prj1. to_string( ) ) ,
345
+ expected_status: expectations. unknown_secret_key,
346
+ } ,
347
+ // Malformed secret key should be treated as an unknown secret key
348
+ TestInput {
349
+ app: app. clone( ) ,
350
+ auth_header: Some ( "Bearer wrong key" . to_string( ) ) ,
351
+ on_behalf_on_header: Some ( "wrong_project" . to_string( ) ) ,
352
+ expected_status: expectations. unknown_secret_key,
353
+ } ,
354
+ // Malformed authorization header
355
+ TestInput {
356
+ app: app. clone( ) ,
357
+ auth_header: Some ( format!( "Token {}" , key. unsafe_to_string( ) ) ) ,
358
+ on_behalf_on_header: Some ( prj1. to_string( ) ) ,
359
+ expected_status: StatusCode :: BAD_REQUEST ,
360
+ } ,
361
+ // Malformed on-behalf-on project id
362
+ TestInput {
363
+ app: app. clone( ) ,
364
+ auth_header: Some ( "Bearer adminkey1" . to_string( ) ) ,
365
+ on_behalf_on_header: Some ( "wrong_project" . to_string( ) ) ,
366
+ expected_status: StatusCode :: BAD_REQUEST ,
367
+ } ,
368
+ ] ;
369
+
370
+ for input in inputs {
371
+ let input_str = format ! ( "{:?}" , input) ;
372
+
373
+ let mut req = Request :: builder ( ) ;
374
+ if let Some ( v) = input. auth_header {
375
+ req = req. header ( "Authorization" , v) ;
376
+ }
377
+ if let Some ( v) = input. on_behalf_on_header {
378
+ req = req. header ( ON_BEHALF_OF_HEADER_NAME , v) ;
379
+ }
380
+
381
+ let resp = input
382
+ . app
383
+ . oneshot ( req. uri ( "/" ) . body ( Body :: empty ( ) ) . unwrap ( ) )
384
+ . await ?;
385
+
386
+ assert_eq ! (
387
+ resp. status( ) ,
388
+ input. expected_status,
389
+ "Input: {}" ,
390
+ input_str
391
+ ) ;
392
+ }
393
+ Ok ( ( ) )
394
+ }
395
+
396
+ #[ tokio:: test]
397
+ async fn test_ensure_authenticated ( ) -> anyhow:: Result < ( ) > {
398
+ let state = make_state ( ) . await ;
399
+
400
+ let app = Router :: new ( )
401
+ . route ( "/" , get ( || async { "Hello, World!" } ) )
402
+ . layer ( middleware:: from_fn ( super :: ensure_authenticated) )
403
+ . layer ( middleware:: from_fn_with_state (
404
+ state. clone ( ) ,
405
+ super :: authenticate,
406
+ ) ) ;
407
+
408
+ run_tests (
409
+ app,
410
+ state,
411
+ TestExpectations {
412
+ unauthenticated : StatusCode :: UNAUTHORIZED ,
413
+ authenticated : StatusCode :: OK ,
414
+ admin_no_project : StatusCode :: BAD_REQUEST ,
415
+ admin_with_project : StatusCode :: OK ,
416
+ unknown_secret_key : StatusCode :: UNAUTHORIZED ,
417
+ } ,
418
+ )
419
+ . await ?;
420
+
421
+ Ok ( ( ) )
422
+ }
423
+
424
+ #[ tokio:: test]
425
+ async fn test_ensure_admin ( ) -> anyhow:: Result < ( ) > {
426
+ let state = make_state ( ) . await ;
427
+
428
+ let app = Router :: new ( )
429
+ . route ( "/" , get ( || async { "Hello, World!" } ) )
430
+ . layer ( middleware:: from_fn ( super :: ensure_admin) )
431
+ . layer ( middleware:: from_fn_with_state (
432
+ state. clone ( ) ,
433
+ super :: authenticate,
434
+ ) ) ;
435
+
436
+ run_tests (
437
+ app,
438
+ state,
439
+ TestExpectations {
440
+ unauthenticated : StatusCode :: UNAUTHORIZED ,
441
+ authenticated : StatusCode :: FORBIDDEN ,
442
+ admin_no_project : StatusCode :: OK ,
443
+ admin_with_project : StatusCode :: OK ,
444
+ unknown_secret_key : StatusCode :: UNAUTHORIZED ,
445
+ } ,
446
+ )
447
+ . await ?;
448
+
449
+ Ok ( ( ) )
450
+ }
451
+
452
+ #[ tokio:: test]
453
+ async fn test_ensure_admin_for_project ( ) -> anyhow:: Result < ( ) > {
454
+ let state = make_state ( ) . await ;
455
+
456
+ let app = Router :: new ( )
457
+ . route ( "/" , get ( || async { "Hello, World!" } ) )
458
+ . layer ( middleware:: from_fn ( super :: ensure_admin_for_project) )
459
+ . layer ( middleware:: from_fn_with_state (
460
+ state. clone ( ) ,
461
+ super :: authenticate,
462
+ ) ) ;
463
+
464
+ run_tests (
465
+ app,
466
+ state,
467
+ TestExpectations {
468
+ unauthenticated : StatusCode :: UNAUTHORIZED ,
469
+ authenticated : StatusCode :: FORBIDDEN ,
470
+ admin_no_project : StatusCode :: BAD_REQUEST ,
471
+ admin_with_project : StatusCode :: OK ,
472
+ unknown_secret_key : StatusCode :: UNAUTHORIZED ,
473
+ } ,
474
+ )
475
+ . await ?;
476
+
477
+ Ok ( ( ) )
478
+ }
479
+ }
0 commit comments