diff --git a/Cargo.lock b/Cargo.lock index 97ee25d658..d102bebcfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3482,9 +3482,11 @@ dependencies = [ "mockito", "port_scanner", "reqwest", + "reqwest-middleware", "serde", "serde_derive", "serde_json", + "task-local-extensions", "tokio", "tracing", "typed-builder", @@ -5509,6 +5511,21 @@ dependencies = [ "webpki-roots 1.0.4", ] +[[package]] +name = "reqwest-middleware" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57f17d28a6e6acfe1733fe24bcd30774d13bffa4b8a22535b4c8c98423088d4e" +dependencies = [ + "anyhow", + "async-trait", + "http 1.4.0", + "reqwest", + "serde", + "thiserror 1.0.69", + "tower-service", +] + [[package]] name = "ring" version = "0.17.14" @@ -6688,6 +6705,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "task-local-extensions" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba323866e5d033818e3240feeb9f7db2c4296674e4d9e16b97b7bf8f490434e8" +dependencies = [ + "pin-utils", +] + [[package]] name = "tempfile" version = "3.23.0" diff --git a/crates/catalog/rest/Cargo.toml b/crates/catalog/rest/Cargo.toml index 916b5ccf75..a3430d785d 100644 --- a/crates/catalog/rest/Cargo.toml +++ b/crates/catalog/rest/Cargo.toml @@ -35,6 +35,7 @@ http = { workspace = true } iceberg = { workspace = true } itertools = { workspace = true } reqwest = { workspace = true } +reqwest-middleware = { version = "0.4", optional = true } serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } @@ -43,9 +44,14 @@ tracing = { workspace = true } typed-builder = { workspace = true } uuid = { workspace = true, features = ["v4"] } +[features] +middleware = ["reqwest-middleware"] + [dev-dependencies] +async-trait = { workspace = true } ctor = { workspace = true } 
iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } mockito = { workspace = true } port_scanner = { workspace = true } +task-local-extensions = "0.1" tokio = { workspace = true } diff --git a/crates/catalog/rest/README.md b/crates/catalog/rest/README.md index e3bb70e94b..cbfab2f8c9 100644 --- a/crates/catalog/rest/README.md +++ b/crates/catalog/rest/README.md @@ -25,3 +25,41 @@ This crate contains the official Native Rust implementation of Apache Iceberg Rest Catalog. See the [API documentation](https://docs.rs/iceberg-catalog-rest/latest) for examples and the full API. + +## Features + +### Middleware Support + +The `middleware` feature enables support for custom HTTP middleware using the `reqwest-middleware` crate. This allows you to add custom behavior to HTTP requests, such as: + +- Request/response logging +- Retry logic +- Rate limiting +- Custom authentication +- Metrics collection + +To enable middleware support, add the feature to your `Cargo.toml`: + +```toml +[dependencies] +iceberg-catalog-rest = { version = "0.8", features = ["middleware"] } +reqwest-middleware = "0.4" +``` + +Example usage: + +```rust +use iceberg_catalog_rest::RestCatalogBuilder; +use reqwest_middleware::ClientBuilder; +use reqwest::Client; + +// Create a client with middleware +let client = ClientBuilder::new(Client::new()) + // Add your middleware here + .build(); + +// Configure the catalog builder with the middleware client, +// then load the catalog from it as usual. +let catalog_builder = RestCatalogBuilder::default() + .with_middleware_client(client); +``` diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index ddbf6a4e01..d5ed603d73 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -67,6 +67,8 @@ impl Default for RestCatalogBuilder { warehouse: None, props: HashMap::new(), client: None, + #[cfg(feature = "middleware")] + middleware_client: None, }) } } @@ -124,6 +126,35 @@ impl RestCatalogBuilder {
self.0.client = Some(client); self } + + /// Configures the catalog with a custom HTTP client with middleware. + /// + /// This method allows you to provide a `reqwest_middleware::ClientWithMiddleware` + /// which wraps a `reqwest::Client` and adds middleware functionality. + /// + /// # Example + /// + /// ```rust,ignore + /// use reqwest::Client; + /// use reqwest_middleware::ClientBuilder; + /// use iceberg_catalog_rest::RestCatalogBuilder; + /// + /// let reqwest_client = Client::new(); + /// let client_with_middleware = ClientBuilder::new(reqwest_client) + /// .with(some_middleware) + /// .build(); + /// + /// let catalog_builder = RestCatalogBuilder::default() + /// .with_middleware_client(client_with_middleware); + /// ``` + #[cfg(feature = "middleware")] + pub fn with_middleware_client( + mut self, + client: reqwest_middleware::ClientWithMiddleware, + ) -> Self { + self.0.middleware_client = Some(client); + self + } } /// Rest catalog configuration. @@ -142,6 +173,10 @@ pub(crate) struct RestCatalogConfig { #[builder(default)] client: Option<Client>, + + #[cfg(feature = "middleware")] + #[builder(default)] + middleware_client: Option<reqwest_middleware::ClientWithMiddleware>, } impl RestCatalogConfig { @@ -199,6 +234,18 @@ impl RestCatalogConfig { self.client.clone() } + /// Check if a middleware client is configured. + #[cfg(feature = "middleware")] + pub(crate) fn has_middleware_client(&self) -> bool { + self.middleware_client.is_some() + } + + /// Get the middleware client from the config. + #[cfg(feature = "middleware")] + pub(crate) fn middleware_client(&self) -> Option<reqwest_middleware::ClientWithMiddleware> { + self.middleware_client.clone() + } + /// Get the token from the config. /// /// The client can use this token to send requests.
@@ -2722,4 +2769,110 @@ mod tests { assert_eq!(err.message(), "Catalog uri is required"); } } + + #[cfg(feature = "middleware")] + #[tokio::test] + async fn test_with_middleware_client() { + use reqwest::Client; + use reqwest_middleware::ClientBuilder; + + let mut server = Server::new_async().await; + + // Mock the config endpoint + let _config_mock = server + .mock("GET", "/v1/config") + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "defaults": {}, + "overrides": {} + }"#, + ) + .create_async() + .await; + + // Create a middleware client + let middleware_client = ClientBuilder::new(Client::new()).build(); + + // Create catalog with middleware client + let _catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .middleware_client(Some(middleware_client)) + .build(), + ); + + // If we got here without panicking, the catalog was created successfully + } + + #[cfg(feature = "middleware")] + #[tokio::test] + async fn test_middleware_intercepts_requests() { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use mockito::Matcher; + use reqwest::Client; + use reqwest_middleware::{ClientBuilder, Middleware, Next}; + + // Custom middleware that counts requests + #[derive(Clone)] + struct CountingMiddleware { + counter: Arc<AtomicUsize>, + } + + #[async_trait::async_trait] + impl Middleware for CountingMiddleware { + async fn handle( + &self, + req: reqwest::Request, + extensions: &mut http::Extensions, + next: Next<'_>, + ) -> reqwest_middleware::Result<reqwest::Response> { + self.counter.fetch_add(1, Ordering::SeqCst); + next.run(req, extensions).await + } + } + + let mut server = Server::new_async().await; + let counter = Arc::new(AtomicUsize::new(0)); + + // Mock the config endpoint + let config_mock = server + .mock("GET", "/v1/config") + .match_header("user-agent", Matcher::Any) + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "defaults": {}, + "overrides": {}
+ }"#, + ) + .create_async() + .await; + + // Create middleware client with counting middleware + let middleware_client = ClientBuilder::new(Client::new()) + .with(CountingMiddleware { + counter: counter.clone(), + }) + .build(); + + // Create catalog with middleware client + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .middleware_client(Some(middleware_client)) + .build(), + ); + + // Make a request to trigger the middleware + let _ = catalog.context().await; + + // Verify the middleware intercepted the request + config_mock.assert(); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } } diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs index 361c036bb6..6c2a728373 100644 --- a/crates/catalog/rest/src/client.rs +++ b/crates/catalog/rest/src/client.rs @@ -30,6 +30,8 @@ use crate::types::{ErrorResponse, TokenResponse}; pub(crate) struct HttpClient { client: Client, + #[cfg(feature = "middleware")] + middleware_client: Option<reqwest_middleware::ClientWithMiddleware>, /// The token to be used for authentication. /// @@ -60,6 +62,8 @@ impl HttpClient { let extra_headers = cfg.extra_headers()?; Ok(HttpClient { client: cfg.client().unwrap_or_default(), + #[cfg(feature = "middleware")] + middleware_client: cfg.middleware_client(), token: Mutex::new(cfg.token()), token_endpoint: cfg.get_token_endpoint(), credential: cfg.credential(), @@ -77,8 +81,20 @@ impl HttpClient { .then(|| cfg.extra_headers()) .transpose()?
.unwrap_or(self.extra_headers); + + let client = cfg.client().unwrap_or(self.client); + + #[cfg(feature = "middleware")] + let middleware_client = if cfg.has_middleware_client() { + cfg.middleware_client() + } else { + self.middleware_client + }; + Ok(HttpClient { - client: cfg.client().unwrap_or(self.client), + client, + #[cfg(feature = "middleware")] + middleware_client, token: Mutex::new(cfg.token().or_else(|| self.token.into_inner())), token_endpoint: if !cfg.get_token_endpoint().is_empty() { cfg.get_token_endpoint() @@ -249,6 +265,18 @@ impl HttpClient { /// Executes the given `Request` and returns a `Response`. pub async fn execute(&self, mut request: Request) -> Result<Response> { request.headers_mut().extend(self.extra_headers.clone()); + + #[cfg(feature = "middleware")] + if let Some(ref middleware_client) = self.middleware_client { + return middleware_client.execute(request).await.map_err(|e| { + Error::new( + ErrorKind::Unexpected, + format!("Failed to execute request: {}", e), + ) + .with_source(e) + }); + } + Ok(self.client.execute(request).await?) }