From 939ad8f35b31a4d479f4275b785bf2c0b1d64b36 Mon Sep 17 00:00:00 2001 From: Shubhendu Ram Tripathi Date: Thu, 19 Feb 2026 12:36:08 +0530 Subject: [PATCH] Added support for SigV4 authentication Benefits: 1. Pluggable authentication - Authenticator trait allows OAuth2, SigV4, or custom auth 2. Backward compatible - Existing OAuth2 behavior preserved 3. Performance optimized - Signing key caching (valid for 24h) 4. Works with catalog implementations which use SigV4 Catalog implementation like MinIO use SigV4 authentication and it would be helpful for such cases. Signed-off-by: Shubhendu Ram Tripathi --- Cargo.lock | 2 + crates/catalog/rest/Cargo.toml | 2 + crates/catalog/rest/src/auth/mod.rs | 184 +++++++ crates/catalog/rest/src/auth/oauth2.rs | 377 +++++++++++++ crates/catalog/rest/src/auth/sigv4.rs | 706 +++++++++++++++++++++++++ crates/catalog/rest/src/catalog.rs | 49 ++ crates/catalog/rest/src/client.rs | 238 +++------ crates/catalog/rest/src/lib.rs | 2 + 8 files changed, 1406 insertions(+), 154 deletions(-) create mode 100644 crates/catalog/rest/src/auth/mod.rs create mode 100644 crates/catalog/rest/src/auth/oauth2.rs create mode 100644 crates/catalog/rest/src/auth/sigv4.rs diff --git a/Cargo.lock b/Cargo.lock index dfce67d7b5..e2e8d8525d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3505,6 +3505,7 @@ version = "0.8.0" dependencies = [ "async-trait", "chrono", + "hmac", "http 1.4.0", "iceberg", "iceberg_test_utils", @@ -3514,6 +3515,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "sha2", "tokio", "tracing", "typed-builder", diff --git a/crates/catalog/rest/Cargo.toml b/crates/catalog/rest/Cargo.toml index de72b6c61b..f10751986e 100644 --- a/crates/catalog/rest/Cargo.toml +++ b/crates/catalog/rest/Cargo.toml @@ -31,6 +31,7 @@ repository = { workspace = true } [dependencies] async-trait = { workspace = true } chrono = { workspace = true } +hmac = "0.12" http = { workspace = true } iceberg = { workspace = true } itertools = { workspace = true } @@ -38,6 +39,7 @@ reqwest = { workspace = true } serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } +sha2 = "0.10" tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } typed-builder = { workspace = true } diff --git a/crates/catalog/rest/src/auth/mod.rs b/crates/catalog/rest/src/auth/mod.rs new file mode 100644 index 0000000000..cfc92e508e --- /dev/null +++ b/crates/catalog/rest/src/auth/mod.rs @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Authentication module for the REST catalog. +//! +//! This module provides pluggable authentication mechanisms for the Iceberg REST catalog. +//! Currently supported authentication methods: +//! +//! - **OAuth2**: Bearer token authentication using OAuth2 client credentials flow +//! - **SigV4**: AWS Signature Version 4 authentication for AWS-compatible services +//! - **None**: No authentication (for development/testing) +//! +//! # Configuration +//! +//! Authentication is configured via catalog properties: +//! +//! ## OAuth2 (default) +//! ```text +//! token = "your-bearer-token" +//! // or +//! credential = "client_id:client_secret" +//! oauth2-server-uri = "https://auth.example.com/oauth/tokens" +//! scope = "catalog" +//! ``` +//! +//! ## SigV4 +//! ```text +//! auth-type = "sigv4" +//! s3.access-key-id = "AKIAIOSFODNN7EXAMPLE" +//! s3.secret-access-key = "wJalrXUtnFEMI/K7MDENG..." +//! s3.session-token = "..." // optional +//! s3.region = "us-east-1" +//! sigv4.service = "s3" // optional, defaults to "s3" +//! ``` + +mod oauth2; +mod sigv4; + +use std::fmt::Debug; + +use async_trait::async_trait; +use iceberg::Result; +pub use oauth2::OAuth2Authenticator; +use reqwest::Request; +pub use sigv4::{SigV4Authenticator, SigV4Credentials}; + +/// Authentication provider trait for REST catalog requests. +/// +/// Implementors of this trait provide a mechanism to authenticate HTTP requests +/// before they are sent to the REST catalog server. +/// +/// # Thread Safety +/// +/// Authenticators must be `Send + Sync` to allow concurrent request authentication. +/// +/// # Example +/// +/// ```ignore +/// use iceberg_catalog_rest::auth::{Authenticator, SigV4Authenticator, SigV4Credentials}; +/// +/// let credentials = SigV4Credentials::new("access_key", "secret_key", None); +/// let auth = SigV4Authenticator::new(credentials, "us-east-1", "s3"); +/// +/// // The authenticator will sign requests with AWS SigV4 +/// auth.authenticate(&mut request).await?; +/// ``` +#[async_trait] +pub trait Authenticator: Send + Sync + Debug { + /// Authenticate a request by modifying its headers. + /// + /// This method is called before each request is sent to the catalog server. + /// Implementations should add appropriate authentication headers to the request. + /// + /// # Arguments + /// + /// * `request` - The mutable request to authenticate + /// + /// # Returns + /// + /// * `Ok(())` if authentication was successful + /// * `Err(...)` if authentication failed (e.g., invalid credentials) + async fn authenticate(&self, request: &mut Request) -> Result<()>; + + /// Invalidate any cached credentials or tokens. + /// + /// This method is called when authentication fails and the client needs to + /// refresh credentials. Implementations should clear any cached tokens. + /// + /// The default implementation does nothing. + async fn invalidate(&self) -> Result<()> { + Ok(()) + } + + /// Regenerate credentials or tokens. + /// + /// This method is called to proactively refresh credentials before they expire. + /// Implementations should fetch new tokens or refresh credentials as needed. + /// + /// The default implementation does nothing. + async fn regenerate(&self) -> Result<()> { + Ok(()) + } + + /// Returns the authentication scheme name for logging and debugging. + fn scheme_name(&self) -> &'static str; + + /// Returns the current cached token, if any. + /// + /// This is primarily used for testing OAuth2 authentication flows. + /// For non-token-based authentication (SigV4, NoAuth), this returns None. + async fn get_token(&self) -> Option { + None + } +} + +/// No authentication provider. +/// +/// This authenticator does nothing and is used when no authentication is required. +#[derive(Debug, Clone, Default)] +pub struct NoAuth; + +#[async_trait] +impl Authenticator for NoAuth { + async fn authenticate(&self, _request: &mut Request) -> Result<()> { + Ok(()) + } + + fn scheme_name(&self) -> &'static str { + "none" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_no_auth_does_nothing() { + let auth = NoAuth; + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + let header_count_before = request.headers().len(); + auth.authenticate(&mut request).await.unwrap(); + let header_count_after = request.headers().len(); + + assert_eq!(header_count_before, header_count_after); + } + + #[test] + fn test_no_auth_scheme_name() { + let auth = NoAuth; + assert_eq!(auth.scheme_name(), "none"); + } + + #[tokio::test] + async fn test_no_auth_invalidate() { + let auth = NoAuth; + auth.invalidate().await.unwrap(); + } + + #[tokio::test] + async fn test_no_auth_regenerate() { + let auth = NoAuth; + auth.regenerate().await.unwrap(); + } +} diff --git a/crates/catalog/rest/src/auth/oauth2.rs b/crates/catalog/rest/src/auth/oauth2.rs new file mode 100644 index 0000000000..45af4c845b --- /dev/null +++ b/crates/catalog/rest/src/auth/oauth2.rs @@ -0,0 +1,377 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! OAuth2 authentication for the REST catalog. + +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; + +use async_trait::async_trait; +use http::StatusCode; +use iceberg::{Error, ErrorKind, Result}; +use reqwest::{Client, Method, Request}; +use tokio::sync::Mutex; + +use super::Authenticator; +use crate::types::{ErrorResponse, TokenResponse}; + +/// OAuth2 authenticator for REST catalog. +/// +/// This authenticator supports two modes: +/// 1. **Token mode**: Use a pre-provided bearer token directly +/// 2. **Credential mode**: Exchange client credentials for a token using OAuth2 flow +/// +/// When both token and credentials are provided, the token takes precedence. +pub struct OAuth2Authenticator { + client: Client, + /// Cached bearer token + token: Mutex>, + /// OAuth2 token endpoint URL + token_endpoint: String, + /// Client credentials: (client_id, client_secret) + credential: Option<(Option, String)>, + /// Additional OAuth2 parameters (scope, audience, resource) + extra_params: HashMap, +} + +impl Debug for OAuth2Authenticator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuth2Authenticator") + .field("token_endpoint", &self.token_endpoint) + .field("has_credential", &self.credential.is_some()) + .field("extra_params", &self.extra_params) + .finish_non_exhaustive() + } +} + +impl OAuth2Authenticator { + /// Creates a new OAuth2 authenticator with a pre-provided token. + /// + /// # Arguments + /// + /// * `token` - The bearer token to use for authentication + pub fn with_token(token: impl Into) -> Self { + Self { + client: Client::new(), + token: Mutex::new(Some(token.into())), + token_endpoint: String::new(), + credential: None, + extra_params: HashMap::new(), + } + } + + /// Creates a new OAuth2 authenticator with client credentials. + /// + /// The authenticator will exchange these credentials for a bearer token + /// using the OAuth2 client credentials flow. + /// + /// # Arguments + /// + /// * `client_id` - The OAuth2 client ID (optional) + /// * `client_secret` - The OAuth2 client secret + /// * `token_endpoint` - The URL of the OAuth2 token endpoint + pub fn with_credentials( + client_id: Option>, + client_secret: impl Into, + token_endpoint: impl Into, + ) -> Self { + Self { + client: Client::new(), + token: Mutex::new(None), + token_endpoint: token_endpoint.into(), + credential: Some((client_id.map(Into::into), client_secret.into())), + extra_params: HashMap::new(), + } + } + + /// Creates a new OAuth2 authenticator with both token and credentials. + /// + /// The token will be used initially, and credentials will be used to + /// refresh the token when needed. + /// + /// # Arguments + /// + /// * `token` - Initial bearer token (optional) + /// * `client_id` - The OAuth2 client ID (optional) + /// * `client_secret` - The OAuth2 client secret + /// * `token_endpoint` - The URL of the OAuth2 token endpoint + pub fn new( + token: Option>, + client_id: Option>, + client_secret: impl Into, + token_endpoint: impl Into, + ) -> Self { + Self { + client: Client::new(), + token: Mutex::new(token.map(Into::into)), + token_endpoint: token_endpoint.into(), + credential: Some((client_id.map(Into::into), client_secret.into())), + extra_params: HashMap::new(), + } + } + + /// Sets additional OAuth2 parameters to include in token requests. + /// + /// Common parameters include: + /// - `scope`: The OAuth2 scope (default: "catalog") + /// - `audience`: The intended audience for the token + /// - `resource`: The resource being accessed + pub fn with_extra_params(mut self, params: HashMap) -> Self { + self.extra_params = params; + self + } + + /// Sets a custom HTTP client for token requests. + pub fn with_client(mut self, client: Client) -> Self { + self.client = client; + self + } + + /// Exchange credentials for a new token. + async fn exchange_credential_for_token(&self) -> Result { + let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Credential must be provided for authentication", + ) + })?; + + let mut params = HashMap::with_capacity(4); + params.insert("grant_type", "client_credentials"); + if let Some(client_id) = client_id { + params.insert("client_id", client_id); + } + params.insert("client_secret", client_secret); + params.extend( + self.extra_params + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())), + ); + + let mut auth_req = self + .client + .request(Method::POST, &self.token_endpoint) + .form(¶ms) + .build()?; + + // Ensure correct content-type for form data + auth_req.headers_mut().insert( + http::header::CONTENT_TYPE, + http::HeaderValue::from_static("application/x-www-form-urlencoded"), + ); + + let auth_url = auth_req.url().clone(); + let auth_resp = self.client.execute(auth_req).await?; + + if auth_resp.status() == StatusCode::OK { + let text = auth_resp + .bytes() + .await + .map_err(|err| err.with_url(auth_url.clone()))?; + + let token_response: TokenResponse = serde_json::from_slice(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("operation", "auth") + .with_context("url", auth_url.to_string()) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?; + + Ok(token_response.access_token) + } else { + let code = auth_resp.status(); + let text = auth_resp + .bytes() + .await + .map_err(|err| err.with_url(auth_url.clone()))?; + + let error_response: ErrorResponse = serde_json::from_slice(&text).map_err(|e| { + Error::new(ErrorKind::Unexpected, "Received unexpected response") + .with_context("code", code.to_string()) + .with_context("operation", "auth") + .with_context("url", auth_url.to_string()) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?; + + Err(Error::from(error_response)) + } + } +} + +#[async_trait] +impl Authenticator for OAuth2Authenticator { + async fn authenticate(&self, request: &mut Request) -> Result<()> { + // Clone the token from lock without holding the lock for entire function + let token = self.token.lock().await.clone(); + + if self.credential.is_none() && token.is_none() { + return Ok(()); + } + + // Either use the provided token or exchange credential for token + let token = match token { + Some(token) => token, + None => { + let new_token = self.exchange_credential_for_token().await?; + *self.token.lock().await = Some(new_token.clone()); + new_token + } + }; + + // Insert bearer token in request + request.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + + Ok(()) + } + + async fn invalidate(&self) -> Result<()> { + *self.token.lock().await = None; + Ok(()) + } + + async fn regenerate(&self) -> Result<()> { + if self.credential.is_some() { + let new_token = self.exchange_credential_for_token().await?; + *self.token.lock().await = Some(new_token); + } + Ok(()) + } + + fn scheme_name(&self) -> &'static str { + "oauth2" + } + + async fn get_token(&self) -> Option { + self.token.lock().await.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_oauth2_with_token() { + let auth = OAuth2Authenticator::with_token("test-token"); + assert_eq!(auth.scheme_name(), "oauth2"); + } + + #[test] + fn test_oauth2_with_credentials() { + let auth = OAuth2Authenticator::with_credentials( + Some("client_id"), + "client_secret", + "https://auth.example.com/token", + ); + assert!(auth.credential.is_some()); + assert_eq!(auth.token_endpoint, "https://auth.example.com/token"); + } + + #[test] + fn test_oauth2_debug_does_not_leak_secrets() { + let auth = OAuth2Authenticator::with_credentials( + Some("client_id"), + "super-secret", + "https://auth.example.com/token", + ); + let debug_str = format!("{auth:?}"); + assert!(!debug_str.contains("super-secret")); + } + + #[tokio::test] + async fn test_oauth2_with_token_adds_bearer_header() { + let auth = OAuth2Authenticator::with_token("test-token-123"); + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + let auth_header = request.headers().get(http::header::AUTHORIZATION); + assert!(auth_header.is_some()); + assert_eq!(auth_header.unwrap(), "Bearer test-token-123"); + } + + #[tokio::test] + async fn test_oauth2_invalidate_clears_token() { + let auth = OAuth2Authenticator::with_token("test-token"); + + // Token should be set initially + assert!(auth.token.lock().await.is_some()); + + // Invalidate should clear the token + auth.invalidate().await.unwrap(); + assert!(auth.token.lock().await.is_none()); + } + + #[tokio::test] + async fn test_oauth2_no_credentials_no_token_skips_auth() { + let auth = OAuth2Authenticator { + client: Client::new(), + token: Mutex::new(None), + token_endpoint: String::new(), + credential: None, + extra_params: HashMap::new(), + }; + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + // No Authorization header should be added + assert!(request.headers().get(http::header::AUTHORIZATION).is_none()); + } + + #[test] + fn test_oauth2_with_extra_params() { + let mut params = HashMap::new(); + params.insert("scope".to_string(), "custom-scope".to_string()); + params.insert("audience".to_string(), "custom-audience".to_string()); + + let auth = OAuth2Authenticator::with_credentials( + Some("client_id"), + "client_secret", + "https://auth.example.com/token", + ) + .with_extra_params(params); + + assert_eq!(auth.extra_params.get("scope").unwrap(), "custom-scope"); + assert_eq!( + auth.extra_params.get("audience").unwrap(), + "custom-audience" + ); + } +} diff --git a/crates/catalog/rest/src/auth/sigv4.rs b/crates/catalog/rest/src/auth/sigv4.rs new file mode 100644 index 0000000000..a448cb240c --- /dev/null +++ b/crates/catalog/rest/src/auth/sigv4.rs @@ -0,0 +1,706 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! AWS Signature Version 4 authentication for the REST catalog. +//! +//! This module provides SigV4 authentication for AWS-compatible services, +//! enabling the REST catalog to authenticate against MinIO, AWS S3 Tables, +//! and other AWS-compatible services. + +use std::fmt::{Debug, Formatter}; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use iceberg::{Error, ErrorKind, Result}; +use reqwest::Request; +use tokio::sync::RwLock; + +use super::Authenticator; + +/// AWS credentials for SigV4 authentication. +#[derive(Clone)] +pub struct SigV4Credentials { + /// AWS access key ID + pub access_key_id: String, + /// AWS secret access key + pub secret_access_key: String, + /// Optional session token for temporary credentials + pub session_token: Option, +} + +impl Debug for SigV4Credentials { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SigV4Credentials") + .field("access_key_id", &self.access_key_id) + .field("secret_access_key", &"[REDACTED]") + .field( + "session_token", + &self.session_token.as_ref().map(|_| "[REDACTED]"), + ) + .finish() + } +} + +impl SigV4Credentials { + /// Creates new SigV4 credentials. + /// + /// # Arguments + /// + /// * `access_key_id` - AWS access key ID + /// * `secret_access_key` - AWS secret access key + /// * `session_token` - Optional session token for temporary credentials + pub fn new( + access_key_id: impl Into, + secret_access_key: impl Into, + session_token: Option>, + ) -> Self { + Self { + access_key_id: access_key_id.into(), + secret_access_key: secret_access_key.into(), + session_token: session_token.map(Into::into), + } + } +} + +/// Cached signing key for performance optimization. +/// +/// AWS SigV4 derives a signing key from the secret key, date, region, and service. +/// This key is valid for an entire day, so caching it significantly improves performance. +struct SigningKeyCache { + key: [u8; 32], + date_stamp: String, + region: String, + service: String, +} + +impl SigningKeyCache { + /// Check if the cached key is valid for the given parameters. + fn is_valid(&self, date_stamp: &str, region: &str, service: &str) -> bool { + self.date_stamp == date_stamp && self.region == region && self.service == service + } +} + +/// AWS Signature Version 4 authenticator. +/// +/// This authenticator signs HTTP requests using AWS Signature Version 4, +/// which is required for AWS services and AWS-compatible services like MinIO. +/// +/// # Example +/// +/// ```ignore +/// use iceberg_catalog_rest::auth::{SigV4Authenticator, SigV4Credentials}; +/// +/// let credentials = SigV4Credentials::new( +/// "AKIAIOSFODNN7EXAMPLE", +/// "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", +/// None, +/// ); +/// +/// let auth = SigV4Authenticator::new(credentials, "us-east-1", "s3"); +/// ``` +pub struct SigV4Authenticator { + credentials: SigV4Credentials, + region: String, + service: String, + /// Cached signing key for performance + signing_key_cache: RwLock>, +} + +impl Debug for SigV4Authenticator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SigV4Authenticator") + .field("credentials", &self.credentials) + .field("region", &self.region) + .field("service", &self.service) + .finish() + } +} + +impl SigV4Authenticator { + /// Creates a new SigV4 authenticator. + /// + /// # Arguments + /// + /// * `credentials` - AWS credentials + /// * `region` - AWS region (e.g., "us-east-1") + /// * `service` - AWS service name (e.g., "s3", "s3tables") + pub fn new( + credentials: SigV4Credentials, + region: impl Into, + service: impl Into, + ) -> Self { + Self { + credentials, + region: region.into(), + service: service.into(), + signing_key_cache: RwLock::new(None), + } + } + + /// Creates a SigV4 authenticator for S3 Tables service. + pub fn for_s3tables(credentials: SigV4Credentials, region: impl Into) -> Self { + Self::new(credentials, region, "s3tables") + } + + /// Creates a SigV4 authenticator for S3 service. + pub fn for_s3(credentials: SigV4Credentials, region: impl Into) -> Self { + Self::new(credentials, region, "s3") + } + + /// Get or compute the signing key. + async fn get_signing_key(&self, date_stamp: &str) -> [u8; 32] { + // Check cache first + { + let cache = self.signing_key_cache.read().await; + if let Some(ref cached) = *cache + && cached.is_valid(date_stamp, &self.region, &self.service) + { + return cached.key; + } + } + + // Compute new signing key + let key = derive_signing_key( + &self.credentials.secret_access_key, + date_stamp, + &self.region, + &self.service, + ); + + // Cache the new key + { + let mut cache = self.signing_key_cache.write().await; + *cache = Some(SigningKeyCache { + key, + date_stamp: date_stamp.to_string(), + region: self.region.clone(), + service: self.service.clone(), + }); + } + + key + } + + /// Sign the request with AWS SigV4. + async fn sign_request(&self, request: &mut Request, now: DateTime) -> Result<()> { + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = now.format("%Y%m%d").to_string(); + + // Get host from URL + let host = request + .url() + .host_str() + .ok_or_else(|| Error::new(ErrorKind::DataInvalid, "Request URL has no host"))? + .to_string(); + + // Add port if non-standard + let host_header = if let Some(port) = request.url().port() { + format!("{host}:{port}") + } else { + host + }; + + // Add required headers + request + .headers_mut() + .insert("x-amz-date", amz_date.parse().unwrap()); + request + .headers_mut() + .insert("host", host_header.parse().unwrap()); + + // Add session token if present + if let Some(ref token) = self.credentials.session_token { + request.headers_mut().insert( + "x-amz-security-token", + token.parse().map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Invalid session token").with_source(e) + })?, + ); + } + + // Compute payload hash + let payload_hash = if let Some(body) = request.body() { + if let Some(bytes) = body.as_bytes() { + sha256_hex(bytes) + } else { + sha256_hex(b"") + } + } else { + sha256_hex(b"") + }; + + request + .headers_mut() + .insert("x-amz-content-sha256", payload_hash.parse().unwrap()); + + // Build canonical request + let (canonical_request, signed_headers) = build_canonical_request(request, &payload_hash)?; + + // Create string to sign + let credential_scope = + format!("{date_stamp}/{}/{}/aws4_request", self.region, self.service); + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}", + sha256_hex(canonical_request.as_bytes()) + ); + + // Get signing key and compute signature + let signing_key = self.get_signing_key(&date_stamp).await; + let signature = hmac_sha256_hex(&signing_key, string_to_sign.as_bytes()); + + // Build Authorization header + let authorization = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + self.credentials.access_key_id, credential_scope, signed_headers, signature + ); + + request.headers_mut().insert( + http::header::AUTHORIZATION, + authorization.parse().map_err(|e| { + Error::new(ErrorKind::DataInvalid, "Invalid authorization header").with_source(e) + })?, + ); + + Ok(()) + } +} + +#[async_trait] +impl Authenticator for SigV4Authenticator { + async fn authenticate(&self, request: &mut Request) -> Result<()> { + self.sign_request(request, Utc::now()).await + } + + async fn invalidate(&self) -> Result<()> { + // Clear the signing key cache + let mut cache = self.signing_key_cache.write().await; + *cache = None; + Ok(()) + } + + fn scheme_name(&self) -> &'static str { + "sigv4" + } +} + +// ============================================================================= +// Cryptographic helper functions +// ============================================================================= + +/// Compute SHA-256 hash and return as lowercase hex string. +fn sha256_hex(data: &[u8]) -> String { + use std::fmt::Write; + + let mut hasher = Sha256::new(); + hasher.update(data); + let result = hasher.finalize(); + + let mut hex = String::with_capacity(64); + for byte in result { + write!(&mut hex, "{byte:02x}").unwrap(); + } + hex +} + +/// Compute HMAC-SHA256 and return as bytes. +fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] { + use hmac::{Hmac, Mac}; + type HmacSha256 = Hmac; + + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size"); + mac.update(data); + mac.finalize().into_bytes().into() +} + +/// Compute HMAC-SHA256 and return as lowercase hex string. +fn hmac_sha256_hex(key: &[u8], data: &[u8]) -> String { + use std::fmt::Write; + + let result = hmac_sha256(key, data); + let mut hex = String::with_capacity(64); + for byte in result { + write!(&mut hex, "{byte:02x}").unwrap(); + } + hex +} + +/// Derive the signing key for AWS SigV4. +fn derive_signing_key(secret_key: &str, date_stamp: &str, region: &str, service: &str) -> [u8; 32] { + let k_date = hmac_sha256( + format!("AWS4{secret_key}").as_bytes(), + date_stamp.as_bytes(), + ); + let k_region = hmac_sha256(&k_date, region.as_bytes()); + let k_service = hmac_sha256(&k_region, service.as_bytes()); + hmac_sha256(&k_service, b"aws4_request") +} + +/// URL-encode a string for use in canonical requests. +fn uri_encode(input: &str, encode_slash: bool) -> String { + let mut encoded = String::with_capacity(input.len() * 3); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(byte as char); + } + b'/' if !encode_slash => { + encoded.push('/'); + } + _ => { + encoded.push_str(&format!("%{byte:02X}")); + } + } + } + encoded +} + +/// Build the canonical request string and return (canonical_request, signed_headers). +fn build_canonical_request(request: &Request, payload_hash: &str) -> Result<(String, String)> { + let method = request.method().as_str(); + + // Canonical URI (path) + let path = request.url().path(); + let canonical_uri = if path.is_empty() { "/" } else { path }; + let canonical_uri = uri_encode(canonical_uri, false); + + // Canonical query string + let canonical_query_string = build_canonical_query_string(request.url().query()); + + // Canonical headers and signed headers + let (canonical_headers, signed_headers) = build_canonical_headers(request)?; + + let canonical_request = format!( + "{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + ); + + Ok((canonical_request, signed_headers)) +} + +/// Build canonical query string from URL query parameters. +fn build_canonical_query_string(query: Option<&str>) -> String { + let Some(query) = query else { + return String::new(); + }; + + let mut params: Vec<(String, String)> = query + .split('&') + .filter(|s| !s.is_empty()) + .map(|param| { + let mut parts = param.splitn(2, '='); + let key = parts.next().unwrap_or(""); + let value = parts.next().unwrap_or(""); + (uri_encode(key, true), uri_encode(value, true)) + }) + .collect(); + + params.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1))); + + params + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&") +} + +/// Build canonical headers and signed headers list. +fn build_canonical_headers(request: &Request) -> Result<(String, String)> { + let mut headers: Vec<(String, String)> = request + .headers() + .iter() + .map(|(name, value)| { + let name = name.as_str().to_lowercase(); + let value = value.to_str().unwrap_or("").trim().to_string(); + (name, value) + }) + .collect(); + + // Sort by header name + headers.sort_by(|a, b| a.0.cmp(&b.0)); + + // Build canonical headers string + let canonical_headers: String = headers + .iter() + .map(|(name, value)| format!("{name}:{value}\n")) + .collect(); + + // Build signed headers list + let signed_headers: String = headers + .iter() + .map(|(name, _)| name.as_str()) + .collect::>() + .join(";"); + + Ok((canonical_headers, signed_headers)) +} + +// Import SHA-256 from the sha2 crate +use sha2::{Digest, Sha256}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_credentials_debug_redacts_secrets() { + let creds = SigV4Credentials::new( + "AKIAIOSFODNN7EXAMPLE", + "my-secret-key", + Some("my-token-value"), + ); + let debug_str = format!("{creds:?}"); + assert!(debug_str.contains("AKIAIOSFODNN7EXAMPLE")); + assert!(debug_str.contains("[REDACTED]")); + // Verify actual secrets are not in the output + assert!(!debug_str.contains("my-secret-key")); + assert!(!debug_str.contains("my-token-value")); + } + + #[test] + fn test_sha256_hex() { + // Test vector from AWS documentation + let hash = sha256_hex(b""); + assert_eq!( + hash, + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + + let hash = sha256_hex(b"hello"); + assert_eq!( + hash, + "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + ); + } + + #[test] + fn test_derive_signing_key() { + // AWS SigV4 test vector + let key = derive_signing_key( + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "20150830", + "us-east-1", + "iam", + ); + + // The signing key should be 32 bytes + assert_eq!(key.len(), 32); + + // Verify against known test vector (from AWS docs) + let expected_hex = "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9"; + let actual_hex = key.iter().map(|b| format!("{b:02x}")).collect::(); + assert_eq!(actual_hex, expected_hex); + } + + #[test] + fn test_uri_encode() { + assert_eq!(uri_encode("hello", true), "hello"); + assert_eq!(uri_encode("hello world", true), "hello%20world"); + assert_eq!(uri_encode("hello/world", true), "hello%2Fworld"); + assert_eq!(uri_encode("hello/world", false), "hello/world"); + assert_eq!(uri_encode("a=b&c=d", true), "a%3Db%26c%3Dd"); + } + + #[test] + fn test_canonical_query_string() { + // Empty query + assert_eq!(build_canonical_query_string(None), ""); + assert_eq!(build_canonical_query_string(Some("")), ""); + + // Single parameter + assert_eq!(build_canonical_query_string(Some("foo=bar")), "foo=bar"); + + // Multiple parameters (should be sorted) + assert_eq!(build_canonical_query_string(Some("b=2&a=1")), "a=1&b=2"); + + // Parameters with encoding needed + assert_eq!( + build_canonical_query_string(Some("foo=bar baz")), + "foo=bar%20baz" + ); + } + + #[tokio::test] + async fn test_sigv4_authenticator_adds_headers() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + // Check that required headers were added + assert!(request.headers().contains_key("x-amz-date")); + assert!(request.headers().contains_key("x-amz-content-sha256")); + assert!(request.headers().contains_key("host")); + assert!(request.headers().contains_key(http::header::AUTHORIZATION)); + + // Check Authorization header format + let auth_header = request + .headers() + .get(http::header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(); + assert!(auth_header.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/")); + assert!(auth_header.contains("SignedHeaders=")); + assert!(auth_header.contains("Signature=")); + } + + #[tokio::test] + async fn test_sigv4_with_session_token() { + let creds = + SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", Some("session-token")); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + // Check that session token header was added + let token_header = request.headers().get("x-amz-security-token"); + assert!(token_header.is_some()); + assert_eq!(token_header.unwrap().to_str().unwrap(), "session-token"); + } + + #[tokio::test] + async fn test_sigv4_signing_key_caching() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + // Cache should be empty initially + assert!(auth.signing_key_cache.read().await.is_none()); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + // Cache should be populated after signing + assert!(auth.signing_key_cache.read().await.is_some()); + } + + #[tokio::test] + async fn test_sigv4_invalidate_clears_cache() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + assert!(auth.signing_key_cache.read().await.is_some()); + + auth.invalidate().await.unwrap(); + assert!(auth.signing_key_cache.read().await.is_none()); + } + + #[test] + fn test_sigv4_scheme_name() { + let creds = SigV4Credentials::new("key", "secret", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + assert_eq!(auth.scheme_name(), "sigv4"); + } + + #[test] + fn test_sigv4_for_s3tables() { + let creds = SigV4Credentials::new("key", "secret", None::); + let auth = SigV4Authenticator::for_s3tables(creds, "us-east-1"); + assert_eq!(auth.service, "s3tables"); + } + + #[test] + fn test_sigv4_for_s3() { + let creds = SigV4Credentials::new("key", "secret", None::); + let auth = SigV4Authenticator::for_s3(creds, "us-east-1"); + assert_eq!(auth.service, "s3"); + } + + #[tokio::test] + async fn test_sigv4_with_query_parameters() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://example.com/test?foo=bar&baz=qux") + .build() + .expect("Failed to build request"); + + // Should not panic and should add auth headers + auth.authenticate(&mut request).await.unwrap(); + assert!(request.headers().contains_key(http::header::AUTHORIZATION)); + } + + #[tokio::test] + async fn test_sigv4_with_post_body() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .post("http://example.com/test") + .body(r#"{"key": "value"}"#) + .build() + .expect("Failed to build request"); + + // Should hash the body and include in signature + auth.authenticate(&mut request).await.unwrap(); + + let content_hash = request + .headers() + .get("x-amz-content-sha256") + .unwrap() + .to_str() + .unwrap(); + // The hash should not be the empty string hash + assert_ne!( + content_hash, + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + } + + #[tokio::test] + async fn test_sigv4_with_custom_port() { + let creds = SigV4Credentials::new("AKIAIOSFODNN7EXAMPLE", "secret-key", None::); + let auth = SigV4Authenticator::new(creds, "us-east-1", "s3"); + + let client = reqwest::Client::new(); + let mut request = client + .get("http://localhost:9000/test") + .build() + .expect("Failed to build request"); + + auth.authenticate(&mut request).await.unwrap(); + + // Host header should include port + let host_header = request.headers().get("host").unwrap().to_str().unwrap(); + assert_eq!(host_header, "localhost:9000"); + } +} diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index eeea1f13e9..3c2da3bafe 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -306,6 +306,55 @@ impl RestCatalogConfig { .unwrap_or(false) } + /// Get the authentication type from config. + /// + /// Returns the value of the `auth-type` property, if set. + /// Supported values: "oauth2" (default), "sigv4" + pub(crate) fn auth_type(&self) -> Option<&str> { + self.props.get("auth-type").map(|s| s.as_str()) + } + + /// Create a SigV4 authenticator from config properties. + /// + /// Required properties: + /// - `s3.access-key-id`: AWS access key ID + /// - `s3.secret-access-key`: AWS secret access key + /// + /// Optional properties: + /// - `s3.session-token`: AWS session token for temporary credentials + /// - `s3.region` or `region`: AWS region (defaults to "us-east-1") + /// - `sigv4.service`: Service name (defaults to "s3") + pub(crate) fn create_sigv4_authenticator(&self) -> Option { + let access_key = self.props.get("s3.access-key-id")?; + let secret_key = self.props.get("s3.secret-access-key")?; + let session_token = self.props.get("s3.session-token").cloned(); + + let region = self + .props + .get("s3.region") + .or_else(|| self.props.get("region")) + .map(|s| s.as_str()) + .unwrap_or("us-east-1"); + + let service = self + .props + .get("sigv4.service") + .map(|s| s.as_str()) + .unwrap_or("s3"); + + let credentials = crate::auth::SigV4Credentials::new( + access_key.clone(), + secret_key.clone(), + session_token, + ); + + Some(crate::auth::SigV4Authenticator::new( + credentials, + region, + service, + )) + } + /// Merge the `RestCatalogConfig` with the a [`CatalogConfig`] (fetched from the REST server). pub(crate) fn merge_with_config(mut self, mut config: CatalogConfig) -> Self { if let Some(uri) = config.overrides.remove("uri") { diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs index 07dc0620da..ba902987fd 100644 --- a/crates/catalog/rest/src/client.rs +++ b/crates/catalog/rest/src/client.rs @@ -17,28 +17,20 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; -use http::StatusCode; -use iceberg::{Error, ErrorKind, Result}; +use iceberg::Result; use reqwest::header::HeaderMap; use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response}; use serde::de::DeserializeOwned; -use tokio::sync::Mutex; use crate::RestCatalogConfig; -use crate::types::{ErrorResponse, TokenResponse}; +use crate::auth::{Authenticator, NoAuth, OAuth2Authenticator}; pub(crate) struct HttpClient { client: Client, - - /// The token to be used for authentication. - /// - /// It's possible to fetch the token from the server while needed. - token: Mutex>, - /// The token endpoint to be used for authentication. - token_endpoint: String, - /// The credential to be used for authentication. - credential: Option<(Option, String)>, + /// The authenticator to use for requests + authenticator: Arc, /// Extra headers to be added to each request. extra_headers: HeaderMap, /// Extra oauth parameters to be added to each authentication request. @@ -51,44 +43,90 @@ impl Debug for HttpClient { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("HttpClient") .field("client", &self.client) + .field("authenticator", &self.authenticator.scheme_name()) .field("extra_headers", &self.extra_headers) .finish_non_exhaustive() } } impl HttpClient { - /// Create a new http client. + /// Create a new http client from configuration. + /// + /// This method automatically creates the appropriate authenticator based on + /// the configuration properties. pub fn new(cfg: &RestCatalogConfig) -> Result { let extra_headers = cfg.extra_headers()?; + let authenticator = Self::create_authenticator(cfg); + Ok(HttpClient { client: cfg.client().unwrap_or_default(), - token: Mutex::new(cfg.token()), - token_endpoint: cfg.get_token_endpoint(), - credential: cfg.credential(), + authenticator, extra_headers, extra_oauth_params: cfg.extra_oauth_params(), disable_header_redaction: cfg.disable_header_redaction(), }) } + /// Create the appropriate authenticator based on configuration. + fn create_authenticator(cfg: &RestCatalogConfig) -> Arc { + // Check for SigV4 authentication + if cfg.auth_type() == Some("sigv4") + && let Some(auth) = cfg.create_sigv4_authenticator() + { + return Arc::new(auth); + } + + // Default to OAuth2 authentication + let token = cfg.token(); + let credential = cfg.credential(); + + if token.is_none() && credential.is_none() { + return Arc::new(NoAuth); + } + + let token_endpoint = cfg.get_token_endpoint(); + let extra_params = cfg.extra_oauth_params(); + + if let Some(token) = token { + if let Some((client_id, client_secret)) = credential { + // Both token and credentials provided + Arc::new( + OAuth2Authenticator::new(Some(token), client_id, client_secret, token_endpoint) + .with_extra_params(extra_params), + ) + } else { + // Token only + Arc::new(OAuth2Authenticator::with_token(token)) + } + } else if let Some((client_id, client_secret)) = credential { + // Credentials only + Arc::new( + OAuth2Authenticator::with_credentials(client_id, client_secret, token_endpoint) + .with_extra_params(extra_params), + ) + } else { + Arc::new(NoAuth) + } + } + /// Update the http client with new configuration. /// /// If cfg carries new value, we will use cfg instead. /// Otherwise, we will keep the old value. + /// + /// Note: The existing authenticator is preserved to maintain cached tokens. + /// The token was already obtained during the initial config fetch. pub fn update_with(self, cfg: &RestCatalogConfig) -> Result { let extra_headers = (!cfg.extra_headers()?.is_empty()) .then(|| cfg.extra_headers()) .transpose()? .unwrap_or(self.extra_headers); + + // Keep the existing authenticator - it has the cached token from the + // initial config fetch. The merged config will have the same credentials. Ok(HttpClient { client: cfg.client().unwrap_or(self.client), - token: Mutex::new(cfg.token().or_else(|| self.token.into_inner())), - token_endpoint: if !cfg.get_token_endpoint().is_empty() { - cfg.get_token_endpoint() - } else { - self.token_endpoint - }, - credential: cfg.credential().or(self.credential), + authenticator: self.authenticator, extra_headers, extra_oauth_params: if !cfg.extra_oauth_params().is_empty() { cfg.extra_oauth_params() @@ -100,89 +138,29 @@ impl HttpClient { } /// This API is testing only to assert the token. + /// + /// Returns the current cached token from the authenticator. + /// For OAuth2, this returns the cached bearer token. + /// For non-token-based auth (SigV4, NoAuth), this returns None. + /// + /// If no token is cached, this will trigger authentication to obtain one. #[cfg(test)] pub(crate) async fn token(&self) -> Option { - let mut req = self - .request(Method::GET, &self.token_endpoint) - .build() - .unwrap(); - self.authenticate(&mut req).await.ok(); - self.token.lock().await.clone() - } - - async fn exchange_credential_for_token(&self) -> Result { - // Credential must exist here. - let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - "Credential must be provided for authentication", - ) - })?; - - let mut params = HashMap::with_capacity(4); - params.insert("grant_type", "client_credentials"); - if let Some(client_id) = client_id { - params.insert("client_id", client_id); + // If no token is cached, trigger authentication to obtain one + if self.authenticator.get_token().await.is_none() { + // Create a dummy request to trigger authentication + if let Ok(mut req) = self.request(Method::GET, "http://localhost/test").build() { + // Ignore authentication errors, just try to populate the cache + let _ = self.authenticator.authenticate(&mut req).await; + } } - params.insert("client_secret", client_secret); - params.extend( - self.extra_oauth_params - .iter() - .map(|(k, v)| (k.as_str(), v.as_str())), - ); - - let mut auth_req = self - .request(Method::POST, &self.token_endpoint) - .form(¶ms) - .build()?; - // extra headers add content-type application/json header it's necessary to override it with proper type - // note that form call doesn't add content-type header if already present - auth_req.headers_mut().insert( - http::header::CONTENT_TYPE, - http::HeaderValue::from_static("application/x-www-form-urlencoded"), - ); - let auth_url = auth_req.url().clone(); - let auth_resp = self.client.execute(auth_req).await?; - - let auth_res: TokenResponse = if auth_resp.status() == StatusCode::OK { - let text = auth_resp - .bytes() - .await - .map_err(|err| err.with_url(auth_url.clone()))?; - Ok(serde_json::from_slice(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("operation", "auth") - .with_context("url", auth_url.to_string()) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?) - } else { - let code = auth_resp.status(); - let text = auth_resp - .bytes() - .await - .map_err(|err| err.with_url(auth_url.clone()))?; - let e: ErrorResponse = serde_json::from_slice(&text).map_err(|e| { - Error::new(ErrorKind::Unexpected, "Received unexpected response") - .with_context("code", code.to_string()) - .with_context("operation", "auth") - .with_context("url", auth_url.to_string()) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?; - Err(Error::from(e)) - }?; - Ok(auth_res.access_token) + self.authenticator.get_token().await } /// Invalidate the current token without generating a new one. On the next request, the client /// will attempt to generate a new token. pub(crate) async fn invalidate_token(&self) -> Result<()> { - *self.token.lock().await = None; - Ok(()) + self.authenticator.invalidate().await } /// Invalidate the current token and set a new one. Generates a new token before invalidating @@ -192,55 +170,7 @@ impl HttpClient { /// If credential is invalid, or the request fails, this method will return an error and leave /// the current token unchanged. pub(crate) async fn regenerate_token(&self) -> Result<()> { - let new_token = self.exchange_credential_for_token().await?; - *self.token.lock().await = Some(new_token.clone()); - Ok(()) - } - - /// Authenticates the request by adding a bearer token to the authorization header. - /// - /// This method supports three authentication modes: - /// - /// 1. **No authentication** - Skip authentication when both `credential` and `token` are missing. - /// 2. **Token authentication** - Use the provided `token` directly for authentication. - /// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it for authentication. - /// - /// When both `credential` and `token` are present, `token` takes precedence. - /// - /// # TODO: Support automatic token refreshing. - async fn authenticate(&self, req: &mut Request) -> Result<()> { - // Clone the token from lock without holding the lock for entire function. - let token = self.token.lock().await.clone(); - - if self.credential.is_none() && token.is_none() { - return Ok(()); - } - - // Either use the provided token or exchange credential for token, cache and use that - let token = match token { - Some(token) => token, - None => { - let token = self.exchange_credential_for_token().await?; - // Update token so that we use it for next request instead of - // exchanging credential for token from the server again - *self.token.lock().await = Some(token.clone()); - token - } - }; - - // Insert token in request. - req.headers_mut().insert( - http::header::AUTHORIZATION, - format!("Bearer {token}").parse().map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - "Invalid token received from catalog server!", - ) - .with_source(e) - })?, - ); - - Ok(()) + self.authenticator.regenerate().await } #[inline] @@ -259,7 +189,7 @@ impl HttpClient { // Queries the Iceberg REST catalog after authentication with the given `Request` and // returns a `Response`. pub async fn query_catalog(&self, mut request: Request) -> Result { - self.authenticate(&mut request).await?; + self.authenticator.authenticate(&mut request).await?; self.execute(request).await } @@ -278,8 +208,8 @@ pub(crate) async fn deserialize_catalog_response( let bytes = response.bytes().await?; serde_json::from_slice::(&bytes).map_err(|e| { - Error::new( - ErrorKind::Unexpected, + iceberg::Error::new( + iceberg::ErrorKind::Unexpected, "Failed to parse response from rest catalog server", ) .with_context("json", String::from_utf8_lossy(&bytes)) @@ -335,9 +265,9 @@ fn format_headers_redacted(headers: &HeaderMap, disable_redaction: bool) -> Stri pub(crate) async fn deserialize_unexpected_catalog_error( response: Response, disable_header_redaction: bool, -) -> Error { - let err = Error::new( - ErrorKind::Unexpected, +) -> iceberg::Error { + let err = iceberg::Error::new( + iceberg::ErrorKind::Unexpected, "Received response with unexpected status code", ) .with_context("status", response.status().to_string()) diff --git a/crates/catalog/rest/src/lib.rs b/crates/catalog/rest/src/lib.rs index 6bee950970..1a0527d6f8 100644 --- a/crates/catalog/rest/src/lib.rs +++ b/crates/catalog/rest/src/lib.rs @@ -51,9 +51,11 @@ #![deny(missing_docs)] +pub mod auth; mod catalog; mod client; mod types; +pub use auth::{Authenticator, NoAuth, OAuth2Authenticator, SigV4Authenticator, SigV4Credentials}; pub use catalog::*; pub use types::*;