diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index ddb718c9..0cc5e860 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,7 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- Add `ServiceExt` trait ([#410]) + +[#410]: https://github.com/tower-rs/tower-http/pull/410 ## Changed diff --git a/tower-http/src/add_extension.rs b/tower-http/src/add_extension.rs index 095646df..67e176dd 100644 --- a/tower-http/src/add_extension.rs +++ b/tower-http/src/add_extension.rs @@ -138,10 +138,10 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Response; use std::{convert::Infallible, sync::Arc}; - use tower::{service_fn, ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; struct State(i32); @@ -164,4 +164,23 @@ mod tests { assert_eq!(1, res); } + + #[tokio::test] + async fn basic_service_ext() { + let state = Arc::new(State(1)); + + let svc = service_fn(|req: Request| async move { + let state = req.extensions().get::>().unwrap(); + Ok::<_, Infallible>(Response::new(state.0)) + }) + .add_extension(state); + + let res = svc + .oneshot(Request::new(Body::empty())) + .await + .unwrap() + .into_body(); + + assert_eq!(1, res); + } } diff --git a/tower-http/src/auth/add_authorization.rs b/tower-http/src/auth/add_authorization.rs index 246c13b6..69ab459b 100644 --- a/tower-http/src/auth/add_authorization.rs +++ b/tower-http/src/auth/add_authorization.rs @@ -189,11 +189,11 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test_helpers::Body; use crate::validate_request::ValidateRequestHeaderLayer; + use crate::{test_helpers::Body, ServiceExt}; use http::{Response, StatusCode}; use std::convert::Infallible; - use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, Service, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn basic() { @@ -216,6 +216,25 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn basic_service_ext() { + // service that requires auth for all requests + let svc = service_fn(echo).require_basic_authorization("foo", "bar"); + + // make a client that adds auth + let mut client = AddAuthorization::basic(svc, "foo", "bar"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn token() { // service that requires auth for all requests @@ -237,6 +256,25 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn token_service_ext() { + // service that requires auth for all requests + let svc = service_fn(echo).require_bearer_authorization("foo"); + + // make a client that adds auth + let mut client = AddAuthorization::bearer(svc, "foo"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn making_header_sensitive() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/auth/async_require_authorization.rs b/tower-http/src/auth/async_require_authorization.rs index f086add2..59a543ec 100644 --- a/tower-http/src/auth/async_require_authorization.rs +++ b/tower-http/src/auth/async_require_authorization.rs @@ -308,10 +308,10 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use futures_util::future::BoxFuture; use http::{header, StatusCode}; - use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, ServiceBuilder, ServiceExt as TowerServiceExt}; #[derive(Clone, Copy)] struct MyAuth; @@ -383,6 +383,20 @@ mod tests { assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } + #[tokio::test] + async fn require_async_auth_401_service_ext() { + let mut service = service_fn(echo).async_require_authorization(MyAuth); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer deez") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } diff --git a/tower-http/src/catch_panic.rs b/tower-http/src/catch_panic.rs index 3f1c2279..394be57e 100644 --- a/tower-http/src/catch_panic.rs +++ b/tower-http/src/catch_panic.rs @@ -366,10 +366,10 @@ mod tests { #![allow(unreachable_code)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Response; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn panic_before_returning_future() { @@ -406,4 +406,21 @@ mod tests { let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } + + #[tokio::test] + async fn panic_in_future_service_ext() { + let svc = service_fn(|_: Request| async { + panic!("future panic"); + Ok::<_, Infallible>(Response::new(Body::empty())) + }) + .catch_panic(); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } } diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index 897304c3..1eb9b570 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -92,7 +92,10 @@ mod tests { use crate::compression::predicate::SizeAbove; use super::*; - use crate::test_helpers::{Body, WithTrailers}; + use crate::{ + test_helpers::{Body, WithTrailers}, + ServiceExt, + }; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; use flate2::read::GzDecoder; use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE}; @@ -103,7 +106,7 @@ mod tests { use std::sync::{Arc, RwLock}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::StreamReader; - use tower::{service_fn, Service, ServiceExt}; + use tower::{service_fn, Service, ServiceExt as TowerServiceExt}; // Compression filter allows every other request to be compressed #[derive(Clone)] @@ -148,6 +151,35 @@ mod tests { assert_eq!(trailers["foo"], "bar"); } + #[tokio::test] + async fn gzip_works_service_ext() { + let mut svc = service_fn(handle).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); + + // decompress the body + // doing this with flate2 as that is much easier than async-compression and blocking during + // tests is fine + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + + // trailers are maintained + assert_eq!(trailers["foo"], "bar"); + } + #[tokio::test] async fn zstd_works() { let svc = service_fn(handle); diff --git a/tower-http/src/decompression/mod.rs b/tower-http/src/decompression/mod.rs index 708df439..f59f56f3 100644 --- a/tower-http/src/decompression/mod.rs +++ b/tower-http/src/decompression/mod.rs @@ -113,13 +113,13 @@ mod tests { use std::io::Write; use super::*; - use crate::test_helpers::Body; use crate::{compression::Compression, test_helpers::WithTrailers}; + use crate::{test_helpers::Body, ServiceExt}; use flate2::write::GzEncoder; use http::Response; use http::{HeaderMap, HeaderName, Request}; use http_body_util::BodyExt; - use tower::{service_fn, Service, ServiceExt}; + use tower::{service_fn, Service, ServiceExt as TowerServiceExt}; #[tokio::test] async fn works() { @@ -143,6 +143,28 @@ mod tests { assert_eq!(trailers["foo"], "bar"); } + #[tokio::test] + async fn works_service_ext() { + let mut client = Compression::new(service_fn(handle)).decompress(); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let collected = body.collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let decompressed_data = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + + // maintains trailers + assert_eq!(trailers["foo"], "bar"); + } + async fn handle(_req: Request) -> Result>, Infallible> { let mut trailers = HeaderMap::new(); trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); diff --git a/tower-http/src/follow_redirect/mod.rs b/tower-http/src/follow_redirect/mod.rs index 516fabf7..200582aa 100644 --- a/tower-http/src/follow_redirect/mod.rs +++ b/tower-http/src/follow_redirect/mod.rs @@ -388,10 +388,10 @@ fn resolve_uri(relative: &str, base: &Uri) -> Option { #[cfg(test)] mod tests { use super::{policy::*, *}; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::header::LOCATION; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn follows() { @@ -411,6 +411,21 @@ mod tests { ); } + #[tokio::test] + async fn follows_service_ext() { + let svc = service_fn(handle).follow_redirect_with_policy(Action::Follow); + let req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(*res.body(), 0); + assert_eq!( + res.extensions().get::().unwrap().0, + "http://example.com/0" + ); + } + #[tokio::test] async fn stops() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 4c731e83..377992fa 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -345,6 +345,9 @@ pub mod validate_request; pub mod body; +mod service_ext; +pub use self::service_ext::ServiceExt; + /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] diff --git a/tower-http/src/metrics/in_flight_requests.rs b/tower-http/src/metrics/in_flight_requests.rs index dbb5e2ff..8968bf3a 100644 --- a/tower-http/src/metrics/in_flight_requests.rs +++ b/tower-http/src/metrics/in_flight_requests.rs @@ -289,9 +289,9 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Request; - use tower::{BoxError, ServiceBuilder}; + use tower::{service_fn, BoxError, ServiceBuilder}; #[tokio::test] async fn basic() { @@ -321,6 +321,32 @@ mod tests { assert_eq!(counter.get(), 0); } + #[tokio::test] + async fn basic_service_ext() { + let counter = InFlightRequestsCounter::new(); + + let mut service = service_fn(echo).count_in_flight_requests(counter.clone()); + assert_eq!(counter.get(), 0); + + // driving service to ready shouldn't increment the counter + std::future::poll_fn(|cx| service.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(counter.get(), 0); + + // creating the response future should increment the count + let response_future = service.call(Request::new(Body::empty())); + assert_eq!(counter.get(), 1); + + // count shouldn't decrement until the full body has been comsumed + let response = response_future.await.unwrap(); + assert_eq!(counter.get(), 1); + + let body = response.into_body(); + crate::test_helpers::to_bytes(body).await.unwrap(); + assert_eq!(counter.get(), 0); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } diff --git a/tower-http/src/normalize_path.rs b/tower-http/src/normalize_path.rs index efc7be52..4ad663da 100644 --- a/tower-http/src/normalize_path.rs +++ b/tower-http/src/normalize_path.rs @@ -140,8 +140,9 @@ fn normalize_trailing_slash(uri: &mut Uri) { #[cfg(test)] mod tests { use super::*; + use crate::ServiceExt; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn works() { @@ -165,6 +166,26 @@ mod tests { assert_eq!(body, "/foo"); } + #[tokio::test] + async fn works_service_ext() { + async fn handle(request: Request<()>) -> Result, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = service_fn(handle).normalize_path(); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo/").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo"); + } + #[test] fn is_noop_if_no_trailing_slash() { let mut uri = "/foo".parse::().unwrap(); diff --git a/tower-http/src/request_id.rs b/tower-http/src/request_id.rs index 1db2d02a..2b580339 100644 --- a/tower-http/src/request_id.rs +++ b/tower-http/src/request_id.rs @@ -485,7 +485,7 @@ impl MakeRequestId for MakeRequestUuid { #[cfg(test)] mod tests { use crate::test_helpers::Body; - use crate::ServiceBuilderExt as _; + use crate::{ServiceBuilderExt as _, ServiceExt}; use http::Response; use std::{ convert::Infallible, @@ -494,7 +494,7 @@ mod tests { Arc, }, }; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[allow(unused_imports)] use super::*; @@ -529,6 +529,35 @@ mod tests { assert_eq!(res.extensions().get::().unwrap().0, "2"); } + #[tokio::test] + async fn basic_service_ext() { + let svc = service_fn(handler) + .propagate_x_request_id() + .set_x_request_id(Counter::default()); + + // header on response + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "0"); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "1"); + + // doesn't override if header is already there + let req = Request::builder() + .header("x-request-id", "foo") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "foo"); + + // extension propagated + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.extensions().get::().unwrap().0, "2"); + } + #[tokio::test] async fn other_middleware_setting_request_id() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/sensitive_headers.rs b/tower-http/src/sensitive_headers.rs index 3bd081db..a1917515 100644 --- a/tower-http/src/sensitive_headers.rs +++ b/tower-http/src/sensitive_headers.rs @@ -382,8 +382,9 @@ where mod tests { #[allow(unused_imports)] use super::*; + use crate::ServiceExt; use http::header; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn multiple_value_header() { @@ -445,4 +446,62 @@ mod tests { assert!(value.is_sensitive()) } } + + #[tokio::test] + async fn multiple_value_header_service_ext() { + async fn response_set_cookie(req: http::Request<()>) -> Result, ()> { + let mut iter = req.headers().get_all(header::COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + + let mut resp = http::Response::new(()); + resp.headers_mut().append( + header::CONTENT_TYPE, + http::HeaderValue::from_static("text/html"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-1"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-2"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-3"), + ); + Ok(resp) + } + + let mut service = service_fn(response_set_cookie) + .set_sensitive_request_headers(vec![header::COOKIE]) + .set_sensitive_response_headers(vec![header::SET_COOKIE]); + + let mut req = http::Request::new(()); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+1")); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+2")); + + let resp = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(!resp + .headers() + .get(header::CONTENT_TYPE) + .unwrap() + .is_sensitive()); + + let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + } } diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs new file mode 100644 index 00000000..59e3638c --- /dev/null +++ b/tower-http/src/service_ext.rs @@ -0,0 +1,776 @@ +#[cfg(feature = "add-extension")] +use crate::add_extension::AddExtension; +#[cfg(all(feature = "validate-request", feature = "auth"))] +use crate::auth::require_authorization::{Basic, Bearer}; +#[cfg(feature = "auth")] +use crate::auth::{AddAuthorization, AsyncRequireAuthorization}; +#[cfg(feature = "catch-panic")] +use crate::catch_panic::{CatchPanic, DefaultResponseForPanic, ResponseForPanic}; +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" +))] +use crate::compression::{Compression, DefaultPredicate, Predicate}; +#[cfg(feature = "cors")] +use crate::cors::Cors; +#[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" +))] +use crate::decompression::{Decompression, RequestDecompression}; +#[cfg(feature = "follow-redirect")] +use crate::follow_redirect::{policy::Standard, FollowRedirect}; +#[cfg(feature = "limit")] +use crate::limit::RequestBodyLimit; +#[cfg(feature = "map-request-body")] +use crate::map_request_body::MapRequestBody; +#[cfg(feature = "map-response-body")] +use crate::map_response_body::MapResponseBody; +#[cfg(feature = "metrics")] +use crate::metrics::in_flight_requests::{InFlightRequests, InFlightRequestsCounter}; +#[cfg(feature = "normalize-path")] +use crate::normalize_path::NormalizePath; +#[cfg(feature = "propagate-header")] +use crate::propagate_header::PropagateHeader; +#[cfg(feature = "request-id")] +use crate::request_id::{MakeRequestId, PropagateRequestId, SetRequestId, X_REQUEST_ID}; +#[cfg(feature = "sensitive-headers")] +use crate::sensitive_headers::{ + SetSensitiveHeaders, SetSensitiveRequestHeaders, SetSensitiveResponseHeaders, +}; +#[cfg(feature = "set-header")] +use crate::set_header::{SetRequestHeader, SetResponseHeader}; +#[cfg(feature = "set-status")] +use crate::set_status::SetStatus; +#[cfg(feature = "validate-request")] +use crate::validate_request::{AcceptHeader, ValidateRequestHeader}; +#[cfg(feature = "trace")] +use crate::{ + classify::{GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier}, + trace::{ + DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, + DefaultOnResponse, Trace, + }, +}; +#[cfg(feature = "timeout")] +use { + crate::timeout::{RequestBodyTimeout, ResponseBodyTimeout, Timeout}, + std::time::Duration, +}; +#[allow(unused_imports)] +use { + http::{header::HeaderName, status::StatusCode}, + http_body::Body, +}; + +/// An extension trait for `Service`s that provides a variety of convenient +/// adapters +pub trait ServiceExt: tower_service::Service { + /// Create a new middleware for adding some shareable value to [request extensions]. + /// + /// See the [add_extension](crate::add_extension) for more details. + /// + /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html + #[cfg(feature = "add-extension")] + fn add_extension(self, value: T) -> AddExtension + where + Self: Sized, + { + AddExtension::new(self, value) + } + + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + /// + /// See the [auth](crate::auth) for more details. + #[cfg(feature = "auth")] + fn require_basic_authorization(self, username: &str, password: &str) -> AddAuthorization + where + Self: Sized, + { + AddAuthorization::basic(self, username, password) + } + + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header will be set to `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [HeaderValue](http::{HeaderValue). + /// + /// See the [auth](crate::auth) for more details. + #[cfg(feature = "auth")] + fn require_bearer_authorization(self, token: &str) -> AddAuthorization + where + Self: Sized, + { + AddAuthorization::bearer(self, token) + } + + /// Authorize requests using a custom scheme. + /// + /// The `Authorization` header is required to have the value provided. + /// + /// See the [auth](crate::auth) for more details. + #[cfg(feature = "auth")] + fn async_require_authorization(self, auth: T) -> AsyncRequireAuthorization + where + Self: Sized, + { + AsyncRequireAuthorization::new(self, auth) + } + + /// Create a new middleware that catches panics and converts them into + /// `500 Internal Server` responses with the default panic handler. + /// + /// See the [catch_panic](crate::catch_panic) for more details. + #[cfg(feature = "catch-panic")] + fn catch_panic(self) -> CatchPanic + where + Self: Sized, + { + CatchPanic::new(self) + } + + /// Create a new middleware that catches panics and converts them into + /// `500 Internal Server` responses with a custom panic handler. + /// + /// See the [catch_panic](crate::catch_panic) for more details. + #[cfg(feature = "catch-panic")] + fn catch_panic_custom(self, panic_handler: T) -> CatchPanic + where + Self: Sized, + T: ResponseForPanic, + { + CatchPanic::custom(self, panic_handler) + } + + /// Creates a new middleware that compress response bodies of the underlying service. + /// + /// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the + /// `Content-Encoding` header to responses. + /// + /// See the [compression](crate::compression) for more details. + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" + ))] + fn compress(self) -> Compression + where + Self: Sized, + { + Compression::new(self) + } + + /// Creates a new middleware that compress response bodies of the underlying service using + /// a custom predicate to determine whether a response should be compressed or not. + /// + /// See the [compression](crate::compression) for more details. + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" + ))] + fn compress_when(self, predicate: C) -> Compression + where + Self: Sized, + C: Predicate, + { + Compression::new(self).compress_when(predicate) + } + + /// Creates a new middleware that adds headers for [CORS][mdn]. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + #[cfg(feature = "cors")] + fn add_cors(self) -> Cors + where + Self: Sized, + { + Cors::new(self) + } + + /// Creates a new middleware that adds headers for [CORS][mdn] using a permissive configuration. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + #[cfg(feature = "cors")] + fn add_cors_permissive(self) -> Cors + where + Self: Sized, + { + Cors::permissive(self) + } + + /// Creates a new middleware that adds headers for [CORS][mdn] using a very permissive configuration. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + #[cfg(feature = "cors")] + fn add_cors_very_permissive(self) -> Cors + where + Self: Sized, + { + Cors::very_permissive(self) + } + + /// Creates a new middleware that decompresses response bodies of the underlying service. + /// + /// This adds the `Accept-Encoding` header to requests and transparently decompresses response + /// bodies based on the `Content-Encoding` header. + /// + /// See the [decompression](crate::decompression) for more details. + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" + ))] + fn decompress(self) -> Decompression + where + Self: Sized, + { + Decompression::new(self) + } + + /// Creates a new middleware that decompresses request bodies and calls its underlying service. + /// + /// Transparently decompresses request bodies based on the `Content-Encoding` header. + /// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` + /// status code will be returned with the accepted encodings in the `Accept-Encoding` header. + /// + /// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but + /// will call the underlying service with the unmodified request if the encoding is not supported. + /// This is disabled by default. + /// + /// See the [decompression](crate::decompression) for more details. + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" + ))] + fn decompress_request(self) -> RequestDecompression + where + Self: Sized, + { + RequestDecompression::new(self) + } + + /// Creates a new middleware that retries requests with a [`Service`](tower::Service) to follow redirection responses. + /// + /// See the [follow_redirect](crate::follow_redirect) for more details. + #[cfg(feature = "follow-redirect")] + fn follow_redirect(self) -> FollowRedirect + where + Self: Sized, + { + FollowRedirect::new(self) + } + + /// Creates a new middleware that retries requests with a [`Service`](tower::Service) to follow redirection responses + /// with the given redirection [`Policy`](crate::follow_redirect::policy::Policy). + /// + /// See the [follow_redirect](crate::follow_redirect) for more details. + #[cfg(feature = "follow-redirect")] + fn follow_redirect_with_policy

(self, policy: P) -> FollowRedirect + where + Self: Sized, + P: Clone, + { + FollowRedirect::with_policy(self, policy) + } + + /// Creates a new middleware that intercepts requests with body lengths greater than the + /// configured limit and converts them into `413 Payload Too Large` responses. + /// + /// See the [limit](crate::limit) for an example. + #[cfg(feature = "limit")] + fn limit_request_body(self, limit: usize) -> RequestBodyLimit + where + Self: Sized, + { + RequestBodyLimit::new(self, limit) + } + + /// Creates a new middleware that apply a transformation to the request body. + #[cfg(feature = "map-request-body")] + fn map_request_body(self, f: F) -> MapRequestBody + where + Self: Sized, + { + MapRequestBody::new(self, f) + } + + /// Creates a new middleware that apply a transformation to the response body. + #[cfg(feature = "map-response-body")] + fn map_response_body(self, f: F) -> MapResponseBody + where + Self: Sized, + { + MapResponseBody::new(self, f) + } + + /// Creates a new middleware that counts the number of in-flight requests. + #[cfg(feature = "metrics")] + fn count_in_flight_requests(self, counter: InFlightRequestsCounter) -> InFlightRequests + where + Self: Sized, + { + InFlightRequests::new(self, counter) + } + + /// Creates a new middleware that normalizes paths. + /// + /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` + /// will be changed to `/foo` before reaching the inner service. + /// + /// See the [normalize_path](crate::normalize_path) for more details. + #[cfg(feature = "normalize-path")] + fn normalize_path(self) -> NormalizePath + where + Self: Sized, + { + NormalizePath::trim_trailing_slash(self) + } + + /// Creates a new middleware that propagates headers from requests to responses. + /// + /// If the header is present on the request it'll be applied to the response as well. This could + /// for example be used to propagate headers such as `X-Request-Id`. + /// + /// See the [propagate_header](crate::propagate_header) for more details. + #[cfg(feature = "propagate-header")] + fn propagate_header(self, header_name: HeaderName) -> PropagateHeader + where + Self: Sized, + { + PropagateHeader::new(self, header_name) + } + + /// Creates a new middleware that propagate request ids from requests to responses. + /// + /// If the request contains a matching header that header will be applied to responses. If a + /// [`RequestId`](crate::request_id::RequestId) extension is also present it will be propagated as well. + /// + /// See the [request_id](crate::request_id) for an example. + #[cfg(feature = "request-id")] + fn propagate_request_id(self, header_name: HeaderName) -> PropagateRequestId + where + Self: Sized, + { + PropagateRequestId::new(self, header_name) + } + + /// Creates a new middleware that propagate request ids from requests to responses + /// using `x-request-id` as the header name. + /// + /// If the request contains a matching header that header will be applied to responses. If a + /// [`RequestId`](crate::request_id::RequestId) extension is also present it will be propagated as well. + /// + /// See the [request_id](crate::request_id) for an example. + #[cfg(feature = "request-id")] + fn propagate_x_request_id(self) -> PropagateRequestId + where + Self: Sized, + { + PropagateRequestId::new(self, HeaderName::from_static(X_REQUEST_ID)) + } + + /// Creates a new middleware that set request id headers and extensions on requests. + /// + /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a + /// header with the same name, then the header will be inserted. + /// + /// Additionally [`RequestId`](crate::request_id::RequestId) will be inserted into + /// the Request extensions so other services can access it. + /// + /// See the [request_id](crate::request_id) for an example. + #[cfg(feature = "request-id")] + fn set_request_id(self, header_name: HeaderName, make_request_id: M) -> SetRequestId + where + Self: Sized, + M: MakeRequestId, + { + SetRequestId::new(self, header_name, make_request_id) + } + + /// Creates a new middleware that set request id headers and extensions on requests + /// using `x-request-id` as the header name. + /// + /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a + /// header with the same name, then the header will be inserted. + /// + /// Additionally [`RequestId`](crate::request_id::RequestId) will be inserted into + /// the Request extensions so other services can access it. + /// + /// See the [request_id](crate::request_id) for an example. + #[cfg(feature = "request-id")] + fn set_x_request_id(self, make_request_id: M) -> SetRequestId + where + Self: Sized, + M: MakeRequestId, + { + SetRequestId::new(self, HeaderName::from_static(X_REQUEST_ID), make_request_id) + } + + /// Creates a new middleware that marks headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_headers(self, headers: I) -> SetSensitiveHeaders + where + Self: Sized, + I: IntoIterator, + { + use std::iter::FromIterator; + let headers = Vec::from_iter(headers); + SetSensitiveRequestHeaders::new( + SetSensitiveResponseHeaders::new(self, headers.iter().cloned()), + headers.into_iter(), + ) + } + + /// Creates a new middleware that marks request headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_request_headers(self, headers: I) -> SetSensitiveRequestHeaders + where + Self: Sized, + I: IntoIterator, + { + SetSensitiveRequestHeaders::new(self, headers) + } + + /// Creates a new middleware that marks response headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_response_headers(self, headers: I) -> SetSensitiveResponseHeaders + where + Self: Sized, + I: IntoIterator, + { + SetSensitiveResponseHeaders::new(self, headers) + } + + /// Creates a new middleware that sets a header on the request. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + #[cfg(feature = "set-header")] + fn override_request_header( + self, + header_name: HeaderName, + make: M, + ) -> SetRequestHeader + where + Self: Sized, + { + SetRequestHeader::overriding(self, header_name, make) + } + + /// Creates a new middleware that sets a header on the request. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + #[cfg(feature = "set-header")] + fn append_request_header(self, header_name: HeaderName, make: M) -> SetRequestHeader + where + Self: Sized, + { + SetRequestHeader::appending(self, header_name, make) + } + + /// Creates a new middleware that sets a header on the request. + /// + /// If a previous value exists for the header, the new value is not inserted. + #[cfg(feature = "set-header")] + fn set_request_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> SetRequestHeader + where + Self: Sized, + { + SetRequestHeader::if_not_present(self, header_name, make) + } + + /// Creates a new middleware that sets a header on the response. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + #[cfg(feature = "set-header")] + fn override_response_header( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader + where + Self: Sized, + { + SetResponseHeader::overriding(self, header_name, make) + } + + /// Creates a new middleware that sets a header on the response. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + #[cfg(feature = "set-header")] + fn append_response_header( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader + where + Self: Sized, + { + SetResponseHeader::appending(self, header_name, make) + } + + /// Creates a new middleware that sets a header on the response. + /// + /// If a previous value exists for the header, the new value is not inserted. + #[cfg(feature = "set-header")] + fn set_response_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader + where + Self: Sized, + { + SetResponseHeader::if_not_present(self, header_name, make) + } + + /// Creates a new middleware that override status codes. + /// + /// See the [set_status](crate::set_status) for more details. + #[cfg(feature = "set-status")] + fn set_status(self, status: StatusCode) -> SetStatus + where + Self: Sized, + { + SetStatus::new(self, status) + } + + /// Creates a new middleware that applies a timeout to requests. + /// + /// If the request does not complete within the specified timeout it will be aborted and a `408 + /// Request Timeout` response will be sent. + /// + /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] + fn timeout(self, timeout: Duration) -> Timeout + where + Self: Sized, + { + Timeout::new(self, timeout) + } + + /// Creates a new middleware that applies a timeout to request bodies. + /// + /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] + fn timeout_request_body(self, timeout: Duration) -> RequestBodyTimeout + where + Self: Sized, + { + RequestBodyTimeout::new(self, timeout) + } + + /// Creates a new middleware that applies a timeout to response bodies. + /// + /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] + fn timeout_response_body(self, timeout: Duration) -> ResponseBodyTimeout + where + Self: Sized, + { + ResponseBodyTimeout::new(self, timeout) + } + + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// using the given [`MakeClassifier`]. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service + #[cfg(feature = "trace")] + fn trace( + self, + make_classifier: M, + ) -> Trace< + Self, + M, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > + where + Self: Sized, + M: MakeClassifier, + { + Trace::new(self, make_classifier) + } + + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// which supports classifying regular HTTP responses based on the status code. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service + #[cfg(feature = "trace")] + fn trace_http( + self, + ) -> Trace< + Self, + SharedClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > + where + Self: Sized, + { + Trace::new_for_http(self) + } + + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// which supports classifying gRPC responses and streams based on the `grpc-status` header. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service + #[cfg(feature = "trace")] + fn trace_grpc( + self, + ) -> Trace< + Self, + SharedClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > + where + Self: Sized, + { + Trace::new_for_grpc(self) + } + + /// Creates a new middleware that authorize requests using a username and password pair. + /// + /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + /// + /// See the [validate_request](crate::validate_request) for an example. + #[cfg(all(feature = "validate-request", feature = "auth"))] + fn validate_basic_authorization( + self, + username: &str, + password: &str, + ) -> ValidateRequestHeader> + where + Self: Sized, + Resbody: Body + Default, + { + ValidateRequestHeader::basic(self, username, password) + } + + /// Creates a new middleware that authorize requests using a "bearer token". + /// Commonly used for OAuth 2. + /// + /// The `Authorization` header is required to be `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). + /// + /// See the [validate_request](crate::validate_request) for an example. + #[cfg(all(feature = "validate-request", feature = "auth"))] + fn validate_bearer_authorization( + self, + token: &str, + ) -> ValidateRequestHeader> + where + Self: Sized, + Resbody: Body + Default, + { + ValidateRequestHeader::bearer(self, token) + } + + /// Creates a new middleware that authorize requests that have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// See `AcceptHeader::new` for when this method panics. + /// + /// See the [validate_request](crate::validate_request) for an example. + #[cfg(feature = "validate-request")] + fn validate_accept_header( + self, + value: &str, + ) -> ValidateRequestHeader> + where + Self: Sized, + Resbody: Body + Default, + { + ValidateRequestHeader::accept(self, value) + } + + /// Creates a new middleware that authorize requests using a custom method. + /// + /// See the [validate_request](crate::validate_request) for an example. + #[cfg(feature = "validate-request")] + fn validate(self, validate: T) -> ValidateRequestHeader + where + Self: Sized, + { + ValidateRequestHeader::custom(self, validate) + } +} + +impl ServiceExt for T where T: tower_service::Service + Sized {} diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 327266af..e7ccde18 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -416,9 +416,9 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::header; - use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn valid_accept_header() { @@ -436,6 +436,20 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn valid_accept_header_service_ext() { + let mut service = service_fn(echo).validate_accept_header("application/json"); + + let request = Request::get("/") + .header(header::ACCEPT, "application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn valid_accept_header_accept_all_json() { let mut service = ServiceBuilder::new()