Skip to content

Commit 060ba8b

Browse files
committed
[Auth] Cover the auth middleware with tests
1 parent b91a5b7 commit 060ba8b

File tree

4 files changed

+291
-13
lines changed

4 files changed

+291
-13
lines changed

cronback-services/src/api/auth.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ impl From<AuthError> for ApiError {
4444
}
4545
}
4646

47+
#[derive(Clone)]
4748
pub struct Authenticator {
4849
store: AuthStore,
4950
}
@@ -190,7 +191,7 @@ impl FromStr for SecretApiKey {
190191
}
191192

192193
impl SecretApiKey {
193-
fn generate() -> Self {
194+
pub fn generate() -> Self {
194195
Self {
195196
key_id: Uuid::new_v4().simple().to_string(),
196197
plain_secret: Uuid::new_v4().simple().to_string(),

cronback-services/src/api/auth_middleware.rs

Lines changed: 284 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
11
use std::sync::Arc;
22

3-
use axum::extract::State;
3+
use axum::extract::{FromRef, State};
44
use axum::http::{self, HeaderMap, HeaderValue, Request};
55
use axum::middleware::Next;
66
use axum::response::IntoResponse;
77
use lib::prelude::*;
88

9-
use super::auth::{AuthError, SecretApiKey};
9+
use super::auth::{AuthError, Authenticator, SecretApiKey};
1010
use super::errors::ApiError;
1111
use super::AppState;
1212

1313
const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of";
1414

15+
// Partial state from the main app state to facilitate writing tests for the
16+
// middleware.
17+
#[derive(Clone)]
18+
pub struct AuthenticationState {
19+
authenticator: Authenticator,
20+
config: super::config::ApiSvcConfig,
21+
}
22+
23+
impl FromRef<Arc<AppState>> for AuthenticationState {
24+
fn from_ref(input: &Arc<AppState>) -> Self {
25+
Self {
26+
authenticator: input.authenticator.clone(),
27+
config: input.context.service_config(),
28+
}
29+
}
30+
}
31+
1532
enum AuthenticationStatus {
1633
Unauthenticated,
1734
Authenticated(ValidShardedId<ProjectId>),
@@ -61,14 +78,14 @@ fn get_auth_key(
6178
}
6279

6380
async fn get_auth_status<B>(
64-
state: &AppState,
81+
state: &AuthenticationState,
6582
req: &Request<B>,
6683
) -> Result<AuthenticationStatus, ApiError> {
6784
let auth_key = get_auth_key(req.headers())?;
6885
let Some(auth_key) = auth_key else {
6986
return Ok(AuthenticationStatus::Unauthenticated);
7087
};
71-
let config = state.context.service_config();
88+
let config = &state.config;
7289
let admin_keys = &config.admin_api_keys;
7390
if admin_keys.contains(&auth_key) {
7491
let project: Option<ValidShardedId<ProjectId>> = req
@@ -98,7 +115,10 @@ async fn get_auth_status<B>(
98115
return Ok(AuthenticationStatus::Unauthenticated);
99116
};
100117

101-
let project = state.authenicator.authenticate(&user_provided_secret).await;
118+
let project = state
119+
.authenticator
120+
.authenticate(&user_provided_secret)
121+
.await;
102122
match project {
103123
| Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)),
104124
| Err(AuthError::AuthFailed(_)) => {
@@ -178,11 +198,11 @@ pub async fn ensure_admin<B>(
178198
/// of the other "ensure_*" middlewares in this module to enforce the expected
179199
/// AuthenticationStatus for a certain route.
180200
pub async fn authenticate<B>(
181-
State(state): State<Arc<AppState>>,
201+
State(state): State<AuthenticationState>,
182202
mut req: Request<B>,
183203
next: Next<B>,
184204
) -> Result<impl IntoResponse, ApiError> {
185-
let auth_status = get_auth_status(state.as_ref(), &req).await?;
205+
let auth_status = get_auth_status(&state, &req).await?;
186206

187207
let project_id = auth_status.project_id();
188208
req.extensions_mut().insert(auth_status);
@@ -200,3 +220,260 @@ pub async fn authenticate<B>(
200220

201221
Ok(resp)
202222
}
223+
224+
#[cfg(test)]
225+
mod tests {
226+
227+
use std::collections::HashSet;
228+
use std::fmt::Debug;
229+
230+
use axum::routing::get;
231+
use axum::{middleware, Router};
232+
use cronback_api_model::admin::CreateAPIkeyRequest;
233+
use hyper::{Body, StatusCode};
234+
use tower::ServiceExt;
235+
236+
use super::*;
237+
use crate::api::auth_store::AuthStore;
238+
use crate::api::config::ApiSvcConfig;
239+
use crate::api::ApiService;
240+
241+
async fn make_state() -> AuthenticationState {
242+
let mut set = HashSet::new();
243+
set.insert("adminkey1".to_string());
244+
set.insert("adminkey2".to_string());
245+
246+
let config = ApiSvcConfig {
247+
address: String::new(),
248+
port: 123,
249+
database_uri: String::new(),
250+
admin_api_keys: set,
251+
log_request_body: false,
252+
log_response_body: false,
253+
};
254+
255+
let db = ApiService::in_memory_database().await.unwrap();
256+
let auth_store = AuthStore::new(db);
257+
let authenticator = Authenticator::new(auth_store);
258+
259+
AuthenticationState {
260+
authenticator,
261+
config,
262+
}
263+
}
264+
265+
struct TestInput {
266+
app: Router,
267+
auth_header: Option<String>,
268+
on_behalf_on_header: Option<String>,
269+
expected_status: StatusCode,
270+
}
271+
272+
impl Debug for TestInput {
273+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274+
f.debug_struct("TestInput")
275+
.field("auth_header", &self.auth_header)
276+
.field("on_behalf_on_header", &self.on_behalf_on_header)
277+
.field("expected_status", &self.expected_status)
278+
.finish()
279+
}
280+
}
281+
282+
struct TestExpectations {
283+
unauthenticated: StatusCode,
284+
authenticated: StatusCode,
285+
admin_no_project: StatusCode,
286+
admin_with_project: StatusCode,
287+
unknown_secret_key: StatusCode,
288+
}
289+
290+
async fn run_tests(
291+
app: Router,
292+
state: AuthenticationState,
293+
expectations: TestExpectations,
294+
) -> anyhow::Result<()> {
295+
// Define one project and generate a key for it.
296+
let prj1 = ProjectId::generate();
297+
let key = state
298+
.authenticator
299+
.gen_key(
300+
CreateAPIkeyRequest {
301+
key_name: "test".to_string(),
302+
metadata: Default::default(),
303+
},
304+
&prj1,
305+
)
306+
.await?;
307+
308+
let inputs = vec![
309+
// Unauthenticated user
310+
TestInput {
311+
app: app.clone(),
312+
auth_header: None,
313+
on_behalf_on_header: None,
314+
expected_status: expectations.unauthenticated,
315+
},
316+
// Authenticated user
317+
TestInput {
318+
app: app.clone(),
319+
auth_header: Some(format!("Bearer {}", key.unsafe_to_string())),
320+
on_behalf_on_header: None,
321+
expected_status: expectations.authenticated,
322+
},
323+
// Admin without project
324+
TestInput {
325+
app: app.clone(),
326+
auth_header: Some("Bearer adminkey1".to_string()),
327+
on_behalf_on_header: None,
328+
expected_status: expectations.admin_no_project,
329+
},
330+
// Admin with project
331+
TestInput {
332+
app: app.clone(),
333+
auth_header: Some("Bearer adminkey1".to_string()),
334+
on_behalf_on_header: Some(prj1.to_string()),
335+
expected_status: expectations.admin_with_project,
336+
},
337+
// Unknown secret key
338+
TestInput {
339+
app: app.clone(),
340+
auth_header: Some(format!(
341+
"Bearer {}",
342+
SecretApiKey::generate().unsafe_to_string()
343+
)),
344+
on_behalf_on_header: Some(prj1.to_string()),
345+
expected_status: expectations.unknown_secret_key,
346+
},
347+
// Malformed secret key should be treated as an unknown secret key
348+
TestInput {
349+
app: app.clone(),
350+
auth_header: Some("Bearer wrong key".to_string()),
351+
on_behalf_on_header: Some("wrong_project".to_string()),
352+
expected_status: expectations.unknown_secret_key,
353+
},
354+
// Malformed authorization header
355+
TestInput {
356+
app: app.clone(),
357+
auth_header: Some(format!("Token {}", key.unsafe_to_string())),
358+
on_behalf_on_header: Some(prj1.to_string()),
359+
expected_status: StatusCode::BAD_REQUEST,
360+
},
361+
// Malformed on-behalf-on project id
362+
TestInput {
363+
app: app.clone(),
364+
auth_header: Some("Bearer adminkey1".to_string()),
365+
on_behalf_on_header: Some("wrong_project".to_string()),
366+
expected_status: StatusCode::BAD_REQUEST,
367+
},
368+
];
369+
370+
for input in inputs {
371+
let input_str = format!("{:?}", input);
372+
373+
let mut req = Request::builder();
374+
if let Some(v) = input.auth_header {
375+
req = req.header("Authorization", v);
376+
}
377+
if let Some(v) = input.on_behalf_on_header {
378+
req = req.header(ON_BEHALF_OF_HEADER_NAME, v);
379+
}
380+
381+
let resp = input
382+
.app
383+
.oneshot(req.uri("/").body(Body::empty()).unwrap())
384+
.await?;
385+
386+
assert_eq!(
387+
resp.status(),
388+
input.expected_status,
389+
"Input: {}",
390+
input_str
391+
);
392+
}
393+
Ok(())
394+
}
395+
396+
#[tokio::test]
397+
async fn test_ensure_authenticated() -> anyhow::Result<()> {
398+
let state = make_state().await;
399+
400+
let app = Router::new()
401+
.route("/", get(|| async { "Hello, World!" }))
402+
.layer(middleware::from_fn(super::ensure_authenticated))
403+
.layer(middleware::from_fn_with_state(
404+
state.clone(),
405+
super::authenticate,
406+
));
407+
408+
run_tests(
409+
app,
410+
state,
411+
TestExpectations {
412+
unauthenticated: StatusCode::UNAUTHORIZED,
413+
authenticated: StatusCode::OK,
414+
admin_no_project: StatusCode::BAD_REQUEST,
415+
admin_with_project: StatusCode::OK,
416+
unknown_secret_key: StatusCode::UNAUTHORIZED,
417+
},
418+
)
419+
.await?;
420+
421+
Ok(())
422+
}
423+
424+
#[tokio::test]
425+
async fn test_ensure_admin() -> anyhow::Result<()> {
426+
let state = make_state().await;
427+
428+
let app = Router::new()
429+
.route("/", get(|| async { "Hello, World!" }))
430+
.layer(middleware::from_fn(super::ensure_admin))
431+
.layer(middleware::from_fn_with_state(
432+
state.clone(),
433+
super::authenticate,
434+
));
435+
436+
run_tests(
437+
app,
438+
state,
439+
TestExpectations {
440+
unauthenticated: StatusCode::UNAUTHORIZED,
441+
authenticated: StatusCode::FORBIDDEN,
442+
admin_no_project: StatusCode::OK,
443+
admin_with_project: StatusCode::OK,
444+
unknown_secret_key: StatusCode::UNAUTHORIZED,
445+
},
446+
)
447+
.await?;
448+
449+
Ok(())
450+
}
451+
452+
#[tokio::test]
453+
async fn test_ensure_admin_for_project() -> anyhow::Result<()> {
454+
let state = make_state().await;
455+
456+
let app = Router::new()
457+
.route("/", get(|| async { "Hello, World!" }))
458+
.layer(middleware::from_fn(super::ensure_admin_for_project))
459+
.layer(middleware::from_fn_with_state(
460+
state.clone(),
461+
super::authenticate,
462+
));
463+
464+
run_tests(
465+
app,
466+
state,
467+
TestExpectations {
468+
unauthenticated: StatusCode::UNAUTHORIZED,
469+
authenticated: StatusCode::FORBIDDEN,
470+
admin_no_project: StatusCode::BAD_REQUEST,
471+
admin_with_project: StatusCode::OK,
472+
unknown_secret_key: StatusCode::UNAUTHORIZED,
473+
},
474+
)
475+
.await?;
476+
477+
Ok(())
478+
}
479+
}

cronback-services/src/api/handlers/admin/api_keys.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub(crate) async fn create(
2424
Extension(project): Extension<ValidShardedId<ProjectId>>,
2525
ValidatedJson(req): ValidatedJson<CreateAPIkeyRequest>,
2626
) -> Result<Json<CreateAPIKeyResponse>, ApiError> {
27-
let key = state.authenicator.gen_key(req, &project).await?;
27+
let key = state.authenticator.gen_key(req, &project).await?;
2828

2929
// This is the only legitimate place where this function should be used.
3030
let key_str = key.unsafe_to_string();
@@ -38,7 +38,7 @@ pub(crate) async fn list(
3838
Extension(project): Extension<ValidShardedId<ProjectId>>,
3939
) -> Result<Paginated<ApiKey>, ApiError> {
4040
let keys = state
41-
.authenicator
41+
.authenticator
4242
.list_keys(&project)
4343
.await
4444
.map_err(|e| AppStateError::DatabaseError(e.to_string()))?
@@ -72,7 +72,7 @@ pub(crate) async fn revoke(
7272
Extension(project): Extension<ValidShardedId<ProjectId>>,
7373
) -> Result<StatusCode, ApiError> {
7474
let deleted = state
75-
.authenicator
75+
.authenticator
7676
.revoke_key(&id, &project)
7777
.await
7878
.map_err(|e| AppStateError::DatabaseError(e.to_string()))?;

0 commit comments

Comments
 (0)