From 12ee7779e9b5be5d697ab43eb7e959dd27ef90cf Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Wed, 20 Sep 2023 21:40:58 +0200 Subject: [PATCH 1/6] Add ServiceExt trait --- tower-http/src/lib.rs | 3 + tower-http/src/service_ext.rs | 429 ++++++++++++++++++++++++++++++++++ 2 files changed, 432 insertions(+) create mode 100644 tower-http/src/service_ext.rs diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 6719ddbd..00e17916 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -342,6 +342,9 @@ pub use self::builder::ServiceBuilderExt; #[cfg(feature = "validate-request")] pub mod validate_request; +mod service_ext; +pub use service_ext::ServiceExt; + /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs new file mode 100644 index 00000000..7db4b48d --- /dev/null +++ b/tower-http/src/service_ext.rs @@ -0,0 +1,429 @@ +#![allow(missing_docs)] // todo + +#[cfg(feature = "add-extension")] +use crate::add_extension::AddExtension; +#[cfg(feature = "auth")] +use crate::auth::{AddAuthorization, AsyncRequireAuthorization}; +#[cfg(feature = "catch-panic")] +use crate::catch_panic::{CatchPanic, DefaultResponseForPanic, ResponseForPanic}; +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" +))] +use crate::compression::{Compression, DefaultPredicate, Predicate}; +#[cfg(feature = "cors")] +use crate::cors::Cors; +#[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" +))] +use crate::decompression::{Decompression, RequestDecompression}; +#[cfg(feature = "follow-redirect")] +use crate::follow_redirect::{policy::Standard, FollowRedirect}; +#[cfg(feature = "limit")] +use crate::limit::RequestBodyLimit; +#[cfg(feature = "map-request-body")] +use crate::map_request_body::MapRequestBody; +#[cfg(feature = "map-response-body")] +use crate::map_response_body::MapResponseBody; +#[cfg(feature = "metrics")] +use crate::metrics::in_flight_requests::{InFlightRequests, InFlightRequestsCounter}; +#[cfg(feature = "normalize-path")] +use crate::normalize_path::NormalizePath; +#[cfg(feature = "propagate-header")] +use crate::propagate_header::PropagateHeader; +#[cfg(feature = "request-id")] +use crate::request_id::{MakeRequestId, PropagateRequestId, SetRequestId, X_REQUEST_ID}; +#[cfg(feature = "sensitive-headers")] +use crate::sensitive_headers::{ + SetSensitiveHeaders, SetSensitiveRequestHeaders, SetSensitiveResponseHeaders, +}; +#[cfg(feature = "set-header")] +use crate::set_header::{SetRequestHeader, SetResponseHeader}; +#[cfg(feature = "set-status")] +use crate::set_status::SetStatus; +#[cfg(feature = "validate-request")] +use crate::{ + auth::require_authorization::{Basic, Bearer}, + validate_request::{AcceptHeader, ValidateRequestHeader}, +}; +#[cfg(feature = "trace")] +use crate::{ + classify::{GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier}, + trace::{ + DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, + DefaultOnResponse, Trace, + }, +}; +#[cfg(feature = "timeout")] +use { + crate::timeout::{RequestBodyTimeout, ResponseBodyTimeout, Timeout}, + std::time::Duration, +}; +#[allow(unused_imports)] +use { + http::{header::HeaderName, status::StatusCode}, + http_body::Body, +}; + +pub trait ServiceExt: tower_service::Service + Sized { + #[cfg(feature = "add-extension")] + fn add_extension(self, value: T) -> AddExtension { + AddExtension::new(self, value) + } + + #[cfg(feature = "auth")] + fn require_basic_authorization(self, username: &str, password: &str) -> AddAuthorization { + AddAuthorization::basic(self, username, password) + } + + #[cfg(feature = "auth")] + fn require_bearer_authorization(self, token: &str) -> AddAuthorization { + AddAuthorization::bearer(self, token) + } + + #[cfg(feature = "auth")] + fn async_require_authorization(self, auth: T) -> AsyncRequireAuthorization { + AsyncRequireAuthorization::new(self, auth) + } + + #[cfg(feature = "catch-panic")] + fn catch_panic(self) -> CatchPanic { + CatchPanic::new(self) + } + + #[cfg(feature = "catch-panic")] + fn catch_panic_custom(self, panic_handler: T) -> CatchPanic + where + T: ResponseForPanic, + { + CatchPanic::custom(self, panic_handler) + } + + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" + ))] + fn compress(self) -> Compression { + Compression::new(self) + } + + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd" + ))] + fn compress_when(self, predicate: C) -> Compression + where + C: Predicate, + { + Compression::new(self).compress_when(predicate) + } + + #[cfg(feature = "cors")] + fn add_cors(self) -> Cors { + Cors::new(self) + } + + #[cfg(feature = "cors")] + fn add_cors_permissive(self) -> Cors { + Cors::permissive(self) + } + + #[cfg(feature = "cors")] + fn add_cors_very_permissive(self) -> Cors { + Cors::very_permissive(self) + } + + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" + ))] + fn decompress(self) -> Decompression { + Decompression::new(self) + } + + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd" + ))] + fn decompress_request(self) -> RequestDecompression { + RequestDecompression::new(self) + } + + #[cfg(feature = "follow-redirect")] + fn follow_redirect(self) -> FollowRedirect { + FollowRedirect::new(self) + } + + #[cfg(feature = "follow-redirect")] + fn follow_redirect_with_policy

(self, policy: P) -> FollowRedirect + where + P: Clone, + { + FollowRedirect::with_policy(self, policy) + } + + #[cfg(feature = "limit")] + fn limit_request_body(self, limit: usize) -> RequestBodyLimit { + RequestBodyLimit::new(self, limit) + } + + #[cfg(feature = "map-request-body")] + fn map_request_body(self, f: F) -> MapRequestBody { + MapRequestBody::new(self, f) + } + + #[cfg(feature = "map-response-body")] + fn map_response_body(self, f: F) -> MapResponseBody { + MapResponseBody::new(self, f) + } + + #[cfg(feature = "metrics")] + fn count_in_flight_requests(self, counter: InFlightRequestsCounter) -> InFlightRequests { + InFlightRequests::new(self, counter) + } + + #[cfg(feature = "normalize-path")] + fn normalize_path(self) -> NormalizePath { + NormalizePath::trim_trailing_slash(self) + } + + #[cfg(feature = "propagate-header")] + fn propagate_header(self, header_name: HeaderName) -> PropagateHeader { + PropagateHeader::new(self, header_name) + } + + #[cfg(feature = "request-id")] + fn propagate_request_id(self, header_name: HeaderName) -> PropagateRequestId { + PropagateRequestId::new(self, header_name) + } + + #[cfg(feature = "request-id")] + fn propagate_x_request_id(self) -> PropagateRequestId { + PropagateRequestId::new(self, HeaderName::from_static(X_REQUEST_ID)) + } + + #[cfg(feature = "request-id")] + fn set_request_id(self, header_name: HeaderName, make_request_id: M) -> SetRequestId + where + M: MakeRequestId, + { + SetRequestId::new(self, header_name, make_request_id) + } + + #[cfg(feature = "request-id")] + fn set_x_request_id( + self, + make_request_id: M, + ) -> SetRequestId + where + M: MakeRequestId, + { + SetRequestId::new(self, HeaderName::from_static(X_REQUEST_ID), make_request_id) + } + + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_headers(self, headers: I) -> SetSensitiveHeaders + where + I: IntoIterator, + { + use std::iter::FromIterator; + let headers = Vec::from_iter(headers); + SetSensitiveRequestHeaders::new( + SetSensitiveResponseHeaders::new(self, headers.iter().cloned()), + headers.into_iter(), + ) + } + + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_request_headers(self, headers: I) -> SetSensitiveRequestHeaders + where + I: IntoIterator, + { + SetSensitiveRequestHeaders::new(self, headers) + } + + #[cfg(feature = "sensitive-headers")] + fn set_sensitive_response_headers(self, headers: I) -> SetSensitiveResponseHeaders + where + I: IntoIterator, + { + SetSensitiveResponseHeaders::new(self, headers) + } + + #[cfg(feature = "set-header")] + fn override_request_header( + self, + header_name: HeaderName, + make: M, + ) -> SetRequestHeader { + SetRequestHeader::overriding(self, header_name, make) + } + + #[cfg(feature = "set-header")] + fn append_request_header( + self, + header_name: HeaderName, + make: M, + ) -> SetRequestHeader { + SetRequestHeader::appending(self, header_name, make) + } + + #[cfg(feature = "set-header")] + fn set_request_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> SetRequestHeader { + SetRequestHeader::if_not_present(self, header_name, make) + } + + #[cfg(feature = "set-header")] + fn override_response_header( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader { + SetResponseHeader::overriding(self, header_name, make) + } + + #[cfg(feature = "set-header")] + fn append_response_header( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader { + SetResponseHeader::appending(self, header_name, make) + } + + #[cfg(feature = "set-header")] + fn set_response_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> SetResponseHeader { + SetResponseHeader::if_not_present(self, header_name, make) + } + + #[cfg(feature = "set-status")] + fn set_status(self, status: StatusCode) -> SetStatus { + SetStatus::new(self, status) + } + + fn timeout(self, timeout: Duration) -> Timeout { + Timeout::new(self, timeout) + } + + fn timeout_request_body(self, timeout: Duration) -> RequestBodyTimeout { + RequestBodyTimeout::new(self, timeout) + } + + fn timeout_response_body(self, timeout: Duration) -> ResponseBodyTimeout { + ResponseBodyTimeout::new(self, timeout) + } + + #[cfg(feature = "trace")] + fn trace( + self, + make_classifier: M, + ) -> Trace< + Self, + M, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > + where + M: MakeClassifier, + { + Trace::new(self, make_classifier) + } + + #[cfg(feature = "trace")] + fn trace_http( + self, + ) -> Trace< + Self, + SharedClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > { + Trace::new_for_http(self) + } + + #[cfg(feature = "trace")] + fn trace_grpc( + self, + ) -> Trace< + Self, + SharedClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > { + Trace::new_for_grpc(self) + } + + #[cfg(feature = "validate-request")] + fn validate_basic_authorization( + self, + username: &str, + password: &str, + ) -> ValidateRequestHeader> + where + Resbody: Body + Default, + { + ValidateRequestHeader::basic(self, username, password) + } + + #[cfg(feature = "validate-request")] + fn validate_bearer_authorization( + self, + token: &str, + ) -> ValidateRequestHeader> + where + Resbody: Body + Default, + { + ValidateRequestHeader::bearer(self, token) + } + + #[cfg(feature = "validate-request")] + fn validate_accept_header( + self, + value: &str, + ) -> ValidateRequestHeader> + where + Resbody: Body + Default, + { + ValidateRequestHeader::accept(self, value) + } + + #[cfg(feature = "validate-request")] + fn validate(self, validate: T) -> ValidateRequestHeader { + ValidateRequestHeader::custom(self, validate) + } +} + +impl ServiceExt for T where T: tower_service::Service + Sized {} From a8dae7a882d1226174828e1c8879c6526a4cb288 Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Sun, 26 Nov 2023 10:24:45 +0100 Subject: [PATCH 2/6] Add ServiceExt docs --- tower-http/src/service_ext.rs | 434 ++++++++++++++++++++++++++++++---- 1 file changed, 391 insertions(+), 43 deletions(-) diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs index 7db4b48d..8a29ca3f 100644 --- a/tower-http/src/service_ext.rs +++ b/tower-http/src/service_ext.rs @@ -1,4 +1,4 @@ -#![allow(missing_docs)] // todo +// todo: tests #[cfg(feature = "add-extension")] use crate::add_extension::AddExtension; @@ -70,50 +70,117 @@ use { http_body::Body, }; -pub trait ServiceExt: tower_service::Service + Sized { +/// An extension trait for `Service`s that provides a variety of convenient +/// adapters +pub trait ServiceExt: tower_service::Service { + /// Create a new middleware for adding some shareable value to [request extensions]. + /// + /// See the [add_extension](crate::add_extension) for more details. + /// + /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[cfg(feature = "add-extension")] - fn add_extension(self, value: T) -> AddExtension { + fn add_extension(self, value: T) -> AddExtension + where + Self: Sized, + { AddExtension::new(self, value) } + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + /// + /// See the [auth](crate::auth) for more details. #[cfg(feature = "auth")] - fn require_basic_authorization(self, username: &str, password: &str) -> AddAuthorization { + fn require_basic_authorization(self, username: &str, password: &str) -> AddAuthorization + where + Self: Sized, + { AddAuthorization::basic(self, username, password) } + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header will be set to `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + /// + /// See the [auth](crate::auth) for more details. #[cfg(feature = "auth")] - fn require_bearer_authorization(self, token: &str) -> AddAuthorization { + fn require_bearer_authorization(self, token: &str) -> AddAuthorization + where + Self: Sized, + { AddAuthorization::bearer(self, token) } + /// Authorize requests using a custom scheme. + /// + /// The `Authorization` header is required to have the value provided. + /// + /// See the [auth](crate::auth) for more details. #[cfg(feature = "auth")] - fn async_require_authorization(self, auth: T) -> AsyncRequireAuthorization { + fn async_require_authorization(self, auth: T) -> AsyncRequireAuthorization + where + Self: Sized, + { AsyncRequireAuthorization::new(self, auth) } + /// Create a new middleware that catches panics and converts them into + /// `500 Internal Server` responses with the default panic handler. + /// + /// See the [catch_panic](crate::catch_panic) for more details. #[cfg(feature = "catch-panic")] - fn catch_panic(self) -> CatchPanic { + fn catch_panic(self) -> CatchPanic + where + Self: Sized, + { CatchPanic::new(self) } + /// Create a new middleware that catches panics and converts them into + /// `500 Internal Server` responses with a custom panic handler. + /// + /// See the [catch_panic](crate::catch_panic) for more details. #[cfg(feature = "catch-panic")] fn catch_panic_custom(self, panic_handler: T) -> CatchPanic where + Self: Sized, T: ResponseForPanic, { CatchPanic::custom(self, panic_handler) } + /// Creates a new middleware that compress response bodies of the underlying service. + /// + /// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the + /// `Content-Encoding` header to responses. + /// + /// See the [compression](crate::compression) for more details. #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd" ))] - fn compress(self) -> Compression { + fn compress(self) -> Compression + where + Self: Sized, + { Compression::new(self) } + /// Creates a new middleware that compress response bodies of the underlying service using + /// a custom predicate to determine whether a response should be compressed or not. + /// + /// See the [compression](crate::compression) for more details. #[cfg(any( feature = "compression-br", feature = "compression-deflate", @@ -122,121 +189,261 @@ pub trait ServiceExt: tower_service::Service + Sized { ))] fn compress_when(self, predicate: C) -> Compression where + Self: Sized, C: Predicate, { Compression::new(self).compress_when(predicate) } + /// Creates a new middleware that adds headers for [CORS][mdn]. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[cfg(feature = "cors")] - fn add_cors(self) -> Cors { + fn add_cors(self) -> Cors + where + Self: Sized, + { Cors::new(self) } + /// Creates a new middleware that adds headers for [CORS][mdn] using a permissive configuration. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + #[cfg(feature = "cors")] #[cfg(feature = "cors")] - fn add_cors_permissive(self) -> Cors { + fn add_cors_permissive(self) -> Cors + where + Self: Sized, + { Cors::permissive(self) } + /// Creates a new middleware that adds headers for [CORS][mdn] using a very permissive configuration. + /// + /// See the [cors](crate::cors) for an example. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[cfg(feature = "cors")] - fn add_cors_very_permissive(self) -> Cors { + fn add_cors_very_permissive(self) -> Cors + where + Self: Sized, + { Cors::very_permissive(self) } + /// Creates a new middleware that decompresses response bodies of the underlying service. + /// + /// This adds the `Accept-Encoding` header to requests and transparently decompresses response + /// bodies based on the `Content-Encoding` header. + /// + /// See the [decompression](crate::decompression) for more details. #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd" ))] - fn decompress(self) -> Decompression { + fn decompress(self) -> Decompression + where + Self: Sized, + { Decompression::new(self) } + /// Creates a new middleware that decompresses request bodies and calls its underlying service. + /// + /// Transparently decompresses request bodies based on the `Content-Encoding` header. + /// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` + /// status code will be returned with the accepted encodings in the `Accept-Encoding` header. + /// + /// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but + /// will call the underlying service with the unmodified request if the encoding is not supported. + /// This is disabled by default. + /// + /// See the [decompression](crate::decompression) for more details. #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd" ))] - fn decompress_request(self) -> RequestDecompression { + fn decompress_request(self) -> RequestDecompression + where + Self: Sized, + { RequestDecompression::new(self) } + /// Creates a new middleware that retries requests with a [`Service`] to follow redirection responses. + /// + /// See the [follow_redirect](crate::follow_redirect) for more details. #[cfg(feature = "follow-redirect")] - fn follow_redirect(self) -> FollowRedirect { + fn follow_redirect(self) -> FollowRedirect + where + Self: Sized, + { FollowRedirect::new(self) } + /// Creates a new middleware that retries requests with a [`Service`] to follow redirection responses + /// with the given redirection [`Policy`]. + /// + /// See the [follow_redirect](crate::follow_redirect) for more details. #[cfg(feature = "follow-redirect")] fn follow_redirect_with_policy

(self, policy: P) -> FollowRedirect where + Self: Sized, P: Clone, { FollowRedirect::with_policy(self, policy) } + /// Creates a new middleware that intercepts requests with body lengths greater than the + /// configured limit and converts them into `413 Payload Too Large` responses. + /// + /// See the [limit](crate::limit) for an example. #[cfg(feature = "limit")] - fn limit_request_body(self, limit: usize) -> RequestBodyLimit { + fn limit_request_body(self, limit: usize) -> RequestBodyLimit + where + Self: Sized, + { RequestBodyLimit::new(self, limit) } + /// Creates a new middleware that apply a transformation to the request body. #[cfg(feature = "map-request-body")] - fn map_request_body(self, f: F) -> MapRequestBody { + fn map_request_body(self, f: F) -> MapRequestBody + where + Self: Sized, + { MapRequestBody::new(self, f) } + /// Creates a new middleware that apply a transformation to the response body. #[cfg(feature = "map-response-body")] - fn map_response_body(self, f: F) -> MapResponseBody { + fn map_response_body(self, f: F) -> MapResponseBody + where + Self: Sized, + { MapResponseBody::new(self, f) } + /// Creates a new middleware that counts the number of in-flight requests. #[cfg(feature = "metrics")] - fn count_in_flight_requests(self, counter: InFlightRequestsCounter) -> InFlightRequests { + fn count_in_flight_requests(self, counter: InFlightRequestsCounter) -> InFlightRequests + where + Self: Sized, + { InFlightRequests::new(self, counter) } + /// Creates a new middleware that normalizes paths. + /// + /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` + /// will be changed to `/foo` before reaching the inner service. + /// + /// See the [normalize_path](crate::normalize_path) for more details. #[cfg(feature = "normalize-path")] - fn normalize_path(self) -> NormalizePath { + fn normalize_path(self) -> NormalizePath + where + Self: Sized, + { NormalizePath::trim_trailing_slash(self) } + /// Creates a new middleware that propagates headers from requests to responses. + /// + /// If the header is present on the request it'll be applied to the response as well. This could + /// for example be used to propagate headers such as `X-Request-Id`. + /// + /// See the [propagate_header](crate::propagate_header) for more details. #[cfg(feature = "propagate-header")] - fn propagate_header(self, header_name: HeaderName) -> PropagateHeader { + fn propagate_header(self, header_name: HeaderName) -> PropagateHeader + where + Self: Sized, + { PropagateHeader::new(self, header_name) } + /// Creates a new middleware that propagate request ids from requests to responses. + /// + /// If the request contains a matching header that header will be applied to responses. If a + /// [`RequestId`] extension is also present it will be propagated as well. + /// + /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] - fn propagate_request_id(self, header_name: HeaderName) -> PropagateRequestId { + fn propagate_request_id(self, header_name: HeaderName) -> PropagateRequestId + where + Self: Sized, + { PropagateRequestId::new(self, header_name) } + /// Creates a new middleware that propagate request ids from requests to responses + /// using `x-request-id` as the header name. + /// + /// If the request contains a matching header that header will be applied to responses. If a + /// [`RequestId`] extension is also present it will be propagated as well. + /// + /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] - fn propagate_x_request_id(self) -> PropagateRequestId { + fn propagate_x_request_id(self) -> PropagateRequestId + where + Self: Sized, + { PropagateRequestId::new(self, HeaderName::from_static(X_REQUEST_ID)) } + /// Creates a new middleware that set request id headers and extensions on requests. + /// + /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a + /// header with the same name, then the header will be inserted. + /// + /// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other + /// services can access it. + /// + /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] fn set_request_id(self, header_name: HeaderName, make_request_id: M) -> SetRequestId where + Self: Sized, M: MakeRequestId, { SetRequestId::new(self, header_name, make_request_id) } + /// Creates a new middleware that set request id headers and extensions on requests + /// using `x-request-id` as the header name. + /// + /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a + /// header with the same name, then the header will be inserted. + /// + /// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other + /// services can access it. + /// + /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] - fn set_x_request_id( - self, - make_request_id: M, - ) -> SetRequestId + fn set_x_request_id(self, make_request_id: M) -> SetRequestId where + Self: Sized, M: MakeRequestId, { SetRequestId::new(self, HeaderName::from_static(X_REQUEST_ID), make_request_id) } + /// Creates a new middleware that marks headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[cfg(feature = "sensitive-headers")] fn set_sensitive_headers(self, headers: I) -> SetSensitiveHeaders where + Self: Sized, I: IntoIterator, { use std::iter::FromIterator; @@ -247,93 +454,175 @@ pub trait ServiceExt: tower_service::Service + Sized { ) } + /// Creates a new middleware that marks request headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[cfg(feature = "sensitive-headers")] fn set_sensitive_request_headers(self, headers: I) -> SetSensitiveRequestHeaders where + Self: Sized, I: IntoIterator, { SetSensitiveRequestHeaders::new(self, headers) } + /// Creates a new middleware that marks response headers as [sensitive]. + /// + /// See the [sensitive_headers](crate::sensitive_headers) for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[cfg(feature = "sensitive-headers")] fn set_sensitive_response_headers(self, headers: I) -> SetSensitiveResponseHeaders where + Self: Sized, I: IntoIterator, { SetSensitiveResponseHeaders::new(self, headers) } + /// Creates a new middleware that sets a header on the request. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. #[cfg(feature = "set-header")] fn override_request_header( self, header_name: HeaderName, make: M, - ) -> SetRequestHeader { + ) -> SetRequestHeader + where + Self: Sized, + { SetRequestHeader::overriding(self, header_name, make) } + /// Creates a new middleware that sets a header on the request. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. #[cfg(feature = "set-header")] - fn append_request_header( - self, - header_name: HeaderName, - make: M, - ) -> SetRequestHeader { + fn append_request_header(self, header_name: HeaderName, make: M) -> SetRequestHeader + where + Self: Sized, + { SetRequestHeader::appending(self, header_name, make) } + /// Creates a new middleware that sets a header on the request. + /// + /// If a previous value exists for the header, the new value is not inserted. #[cfg(feature = "set-header")] fn set_request_header_if_not_present( self, header_name: HeaderName, make: M, - ) -> SetRequestHeader { + ) -> SetRequestHeader + where + Self: Sized, + { SetRequestHeader::if_not_present(self, header_name, make) } + /// Creates a new middleware that sets a header on the response. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. #[cfg(feature = "set-header")] fn override_response_header( self, header_name: HeaderName, make: M, - ) -> SetResponseHeader { + ) -> SetResponseHeader + where + Self: Sized, + { SetResponseHeader::overriding(self, header_name, make) } + /// Creates a new middleware that sets a header on the response. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. #[cfg(feature = "set-header")] fn append_response_header( self, header_name: HeaderName, make: M, - ) -> SetResponseHeader { + ) -> SetResponseHeader + where + Self: Sized, + { SetResponseHeader::appending(self, header_name, make) } + /// Creates a new middleware that sets a header on the response. + /// + /// If a previous value exists for the header, the new value is not inserted. #[cfg(feature = "set-header")] fn set_response_header_if_not_present( self, header_name: HeaderName, make: M, - ) -> SetResponseHeader { + ) -> SetResponseHeader + where + Self: Sized, + { SetResponseHeader::if_not_present(self, header_name, make) } + /// Creates a new middleware that override status codes. + /// + /// See the [set_status](crate::set_status) for more details. #[cfg(feature = "set-status")] - fn set_status(self, status: StatusCode) -> SetStatus { + fn set_status(self, status: StatusCode) -> SetStatus + where + Self: Sized, + { SetStatus::new(self, status) } - fn timeout(self, timeout: Duration) -> Timeout { + /// Creates a new middleware that applies a timeout to requests. + /// + /// If the request does not complete within the specified timeout it will be aborted and a `408 + /// Request Timeout` response will be sent. + /// + /// See the [timeout](crate::timeout) for an example. + fn timeout(self, timeout: Duration) -> Timeout + where + Self: Sized, + { Timeout::new(self, timeout) } - fn timeout_request_body(self, timeout: Duration) -> RequestBodyTimeout { + /// Creates a new middleware that applies a timeout to request bodies. + /// + /// See the [timeout](crate::timeout) for an example. + fn timeout_request_body(self, timeout: Duration) -> RequestBodyTimeout + where + Self: Sized, + { RequestBodyTimeout::new(self, timeout) } - fn timeout_response_body(self, timeout: Duration) -> ResponseBodyTimeout { + /// Creates a new middleware that applies a timeout to response bodies. + /// + /// See the [timeout](crate::timeout) for an example. + fn timeout_response_body(self, timeout: Duration) -> ResponseBodyTimeout + where + Self: Sized, + { ResponseBodyTimeout::new(self, timeout) } + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// using the given [`MakeClassifier`]. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service #[cfg(feature = "trace")] fn trace( self, @@ -349,11 +638,19 @@ pub trait ServiceExt: tower_service::Service + Sized { DefaultOnFailure, > where + Self: Sized, M: MakeClassifier, { Trace::new(self, make_classifier) } + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// which supports classifying regular HTTP responses based on the status code. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service #[cfg(feature = "trace")] fn trace_http( self, @@ -366,10 +663,20 @@ pub trait ServiceExt: tower_service::Service + Sized { DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, - > { + > + where + Self: Sized, + { Trace::new_for_http(self) } + /// Creates a new middleware that adds high level [tracing] to a [`Service`] + /// which supports classifying gRPC responses and streams based on the `grpc-status` header. + /// + /// See the [trace](crate::trace) for an example. + /// + /// [tracing]: https://crates.io/crates/tracing + /// [`Service`]: tower_service::Service #[cfg(feature = "trace")] fn trace_grpc( self, @@ -382,10 +689,22 @@ pub trait ServiceExt: tower_service::Service + Sized { DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, - > { + > + where + Self: Sized, + { Trace::new_for_grpc(self) } + /// Creates a new middleware that authorize requests using a username and password pair. + /// + /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + /// + /// See the [validate_request](crate::validate_request) for an example. #[cfg(feature = "validate-request")] fn validate_basic_authorization( self, @@ -393,35 +712,64 @@ pub trait ServiceExt: tower_service::Service + Sized { password: &str, ) -> ValidateRequestHeader> where + Self: Sized, Resbody: Body + Default, { ValidateRequestHeader::basic(self, username, password) } + /// Creates a new middleware that authorize requests using a "bearer token". + /// Commonly used for OAuth 2. + /// + /// The `Authorization` header is required to be `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + /// + /// See the [validate_request](crate::validate_request) for an example. #[cfg(feature = "validate-request")] fn validate_bearer_authorization( self, token: &str, ) -> ValidateRequestHeader> where + Self: Sized, Resbody: Body + Default, { ValidateRequestHeader::bearer(self, token) } + /// Creates a new middleware that authorize requests that have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// See `AcceptHeader::new` for when this method panics. + /// + /// See the [validate_request](crate::validate_request) for an example. #[cfg(feature = "validate-request")] fn validate_accept_header( self, value: &str, ) -> ValidateRequestHeader> where + Self: Sized, Resbody: Body + Default, { ValidateRequestHeader::accept(self, value) } + /// Creates a new middleware that authorize requests using a custom method. + /// + /// See the [validate_request](crate::validate_request) for an example. #[cfg(feature = "validate-request")] - fn validate(self, validate: T) -> ValidateRequestHeader { + fn validate(self, validate: T) -> ValidateRequestHeader + where + Self: Sized, + { ValidateRequestHeader::custom(self, validate) } } From b8b106d59641c9a991096ecf1792020e09307941 Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Sun, 26 Nov 2023 11:28:14 +0100 Subject: [PATCH 3/6] Add ServiceExt tests --- .vscode/settings.json | 11 ---- tower-http/src/add_extension.rs | 23 ++++++- tower-http/src/auth/add_authorization.rs | 42 ++++++++++++- .../src/auth/async_require_authorization.rs | 18 +++++- tower-http/src/catch_panic.rs | 21 ++++++- tower-http/src/compression/mod.rs | 36 ++++++++++- tower-http/src/decompression/mod.rs | 26 +++++++- tower-http/src/follow_redirect/mod.rs | 19 +++++- tower-http/src/metrics/in_flight_requests.rs | 30 ++++++++- tower-http/src/normalize_path.rs | 23 ++++++- tower-http/src/request_id.rs | 33 +++++++++- tower-http/src/sensitive_headers.rs | 61 ++++++++++++++++++- tower-http/src/service_ext.rs | 2 - tower-http/src/validate_request.rs | 18 +++++- 14 files changed, 328 insertions(+), 35 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index b54abc87..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "editor.detectIndentation": true, - "editor.insertSpaces": false, - "[rust]": { - "editor.tabSize": 2, - "editor.rulers": [ - 100 - ], - }, - "rust-analyzer.rustfmt.extraArgs": [], -} \ No newline at end of file diff --git a/tower-http/src/add_extension.rs b/tower-http/src/add_extension.rs index 095646df..67e176dd 100644 --- a/tower-http/src/add_extension.rs +++ b/tower-http/src/add_extension.rs @@ -138,10 +138,10 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Response; use std::{convert::Infallible, sync::Arc}; - use tower::{service_fn, ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; struct State(i32); @@ -164,4 +164,23 @@ mod tests { assert_eq!(1, res); } + + #[tokio::test] + async fn basic_service_ext() { + let state = Arc::new(State(1)); + + let svc = service_fn(|req: Request| async move { + let state = req.extensions().get::>().unwrap(); + Ok::<_, Infallible>(Response::new(state.0)) + }) + .add_extension(state); + + let res = svc + .oneshot(Request::new(Body::empty())) + .await + .unwrap() + .into_body(); + + assert_eq!(1, res); + } } diff --git a/tower-http/src/auth/add_authorization.rs b/tower-http/src/auth/add_authorization.rs index 246c13b6..69ab459b 100644 --- a/tower-http/src/auth/add_authorization.rs +++ b/tower-http/src/auth/add_authorization.rs @@ -189,11 +189,11 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test_helpers::Body; use crate::validate_request::ValidateRequestHeaderLayer; + use crate::{test_helpers::Body, ServiceExt}; use http::{Response, StatusCode}; use std::convert::Infallible; - use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, Service, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn basic() { @@ -216,6 +216,25 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn basic_service_ext() { + // service that requires auth for all requests + let svc = service_fn(echo).require_basic_authorization("foo", "bar"); + + // make a client that adds auth + let mut client = AddAuthorization::basic(svc, "foo", "bar"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn token() { // service that requires auth for all requests @@ -237,6 +256,25 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn token_service_ext() { + // service that requires auth for all requests + let svc = service_fn(echo).require_bearer_authorization("foo"); + + // make a client that adds auth + let mut client = AddAuthorization::bearer(svc, "foo"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn making_header_sensitive() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/auth/async_require_authorization.rs b/tower-http/src/auth/async_require_authorization.rs index f086add2..59a543ec 100644 --- a/tower-http/src/auth/async_require_authorization.rs +++ b/tower-http/src/auth/async_require_authorization.rs @@ -308,10 +308,10 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use futures_util::future::BoxFuture; use http::{header, StatusCode}; - use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, ServiceBuilder, ServiceExt as TowerServiceExt}; #[derive(Clone, Copy)] struct MyAuth; @@ -383,6 +383,20 @@ mod tests { assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } + #[tokio::test] + async fn require_async_auth_401_service_ext() { + let mut service = service_fn(echo).async_require_authorization(MyAuth); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer deez") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } diff --git a/tower-http/src/catch_panic.rs b/tower-http/src/catch_panic.rs index 3f1c2279..394be57e 100644 --- a/tower-http/src/catch_panic.rs +++ b/tower-http/src/catch_panic.rs @@ -366,10 +366,10 @@ mod tests { #![allow(unreachable_code)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Response; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn panic_before_returning_future() { @@ -406,4 +406,21 @@ mod tests { let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } + + #[tokio::test] + async fn panic_in_future_service_ext() { + let svc = service_fn(|_: Request| async { + panic!("future panic"); + Ok::<_, Infallible>(Response::new(Body::empty())) + }) + .catch_panic(); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } } diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index 897304c3..1eb9b570 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -92,7 +92,10 @@ mod tests { use crate::compression::predicate::SizeAbove; use super::*; - use crate::test_helpers::{Body, WithTrailers}; + use crate::{ + test_helpers::{Body, WithTrailers}, + ServiceExt, + }; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; use flate2::read::GzDecoder; use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE}; @@ -103,7 +106,7 @@ mod tests { use std::sync::{Arc, RwLock}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::StreamReader; - use tower::{service_fn, Service, ServiceExt}; + use tower::{service_fn, Service, ServiceExt as TowerServiceExt}; // Compression filter allows every other request to be compressed #[derive(Clone)] @@ -148,6 +151,35 @@ mod tests { assert_eq!(trailers["foo"], "bar"); } + #[tokio::test] + async fn gzip_works_service_ext() { + let mut svc = service_fn(handle).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); + + // decompress the body + // doing this with flate2 as that is much easier than async-compression and blocking during + // tests is fine + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + + // trailers are maintained + assert_eq!(trailers["foo"], "bar"); + } + #[tokio::test] async fn zstd_works() { let svc = service_fn(handle); diff --git a/tower-http/src/decompression/mod.rs b/tower-http/src/decompression/mod.rs index 708df439..f59f56f3 100644 --- a/tower-http/src/decompression/mod.rs +++ b/tower-http/src/decompression/mod.rs @@ -113,13 +113,13 @@ mod tests { use std::io::Write; use super::*; - use crate::test_helpers::Body; use crate::{compression::Compression, test_helpers::WithTrailers}; + use crate::{test_helpers::Body, ServiceExt}; use flate2::write::GzEncoder; use http::Response; use http::{HeaderMap, HeaderName, Request}; use http_body_util::BodyExt; - use tower::{service_fn, Service, ServiceExt}; + use tower::{service_fn, Service, ServiceExt as TowerServiceExt}; #[tokio::test] async fn works() { @@ -143,6 +143,28 @@ mod tests { assert_eq!(trailers["foo"], "bar"); } + #[tokio::test] + async fn works_service_ext() { + let mut client = Compression::new(service_fn(handle)).decompress(); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let collected = body.collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let decompressed_data = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + + // maintains trailers + assert_eq!(trailers["foo"], "bar"); + } + async fn handle(_req: Request) -> Result>, Infallible> { let mut trailers = HeaderMap::new(); trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); diff --git a/tower-http/src/follow_redirect/mod.rs b/tower-http/src/follow_redirect/mod.rs index 516fabf7..200582aa 100644 --- a/tower-http/src/follow_redirect/mod.rs +++ b/tower-http/src/follow_redirect/mod.rs @@ -388,10 +388,10 @@ fn resolve_uri(relative: &str, base: &Uri) -> Option { #[cfg(test)] mod tests { use super::{policy::*, *}; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::header::LOCATION; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn follows() { @@ -411,6 +411,21 @@ mod tests { ); } + #[tokio::test] + async fn follows_service_ext() { + let svc = service_fn(handle).follow_redirect_with_policy(Action::Follow); + let req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(*res.body(), 0); + assert_eq!( + res.extensions().get::().unwrap().0, + "http://example.com/0" + ); + } + #[tokio::test] async fn stops() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/metrics/in_flight_requests.rs b/tower-http/src/metrics/in_flight_requests.rs index dbb5e2ff..8968bf3a 100644 --- a/tower-http/src/metrics/in_flight_requests.rs +++ b/tower-http/src/metrics/in_flight_requests.rs @@ -289,9 +289,9 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::Request; - use tower::{BoxError, ServiceBuilder}; + use tower::{service_fn, BoxError, ServiceBuilder}; #[tokio::test] async fn basic() { @@ -321,6 +321,32 @@ mod tests { assert_eq!(counter.get(), 0); } + #[tokio::test] + async fn basic_service_ext() { + let counter = InFlightRequestsCounter::new(); + + let mut service = service_fn(echo).count_in_flight_requests(counter.clone()); + assert_eq!(counter.get(), 0); + + // driving service to ready shouldn't increment the counter + std::future::poll_fn(|cx| service.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(counter.get(), 0); + + // creating the response future should increment the count + let response_future = service.call(Request::new(Body::empty())); + assert_eq!(counter.get(), 1); + + // count shouldn't decrement until the full body has been comsumed + let response = response_future.await.unwrap(); + assert_eq!(counter.get(), 1); + + let body = response.into_body(); + crate::test_helpers::to_bytes(body).await.unwrap(); + assert_eq!(counter.get(), 0); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } diff --git a/tower-http/src/normalize_path.rs b/tower-http/src/normalize_path.rs index efc7be52..4ad663da 100644 --- a/tower-http/src/normalize_path.rs +++ b/tower-http/src/normalize_path.rs @@ -140,8 +140,9 @@ fn normalize_trailing_slash(uri: &mut Uri) { #[cfg(test)] mod tests { use super::*; + use crate::ServiceExt; use std::convert::Infallible; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn works() { @@ -165,6 +166,26 @@ mod tests { assert_eq!(body, "/foo"); } + #[tokio::test] + async fn works_service_ext() { + async fn handle(request: Request<()>) -> Result, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = service_fn(handle).normalize_path(); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo/").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo"); + } + #[test] fn is_noop_if_no_trailing_slash() { let mut uri = "/foo".parse::().unwrap(); diff --git a/tower-http/src/request_id.rs b/tower-http/src/request_id.rs index 1db2d02a..2b580339 100644 --- a/tower-http/src/request_id.rs +++ b/tower-http/src/request_id.rs @@ -485,7 +485,7 @@ impl MakeRequestId for MakeRequestUuid { #[cfg(test)] mod tests { use crate::test_helpers::Body; - use crate::ServiceBuilderExt as _; + use crate::{ServiceBuilderExt as _, ServiceExt}; use http::Response; use std::{ convert::Infallible, @@ -494,7 +494,7 @@ mod tests { Arc, }, }; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[allow(unused_imports)] use super::*; @@ -529,6 +529,35 @@ mod tests { assert_eq!(res.extensions().get::().unwrap().0, "2"); } + #[tokio::test] + async fn basic_service_ext() { + let svc = service_fn(handler) + .propagate_x_request_id() + .set_x_request_id(Counter::default()); + + // header on response + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "0"); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "1"); + + // doesn't override if header is already there + let req = Request::builder() + .header("x-request-id", "foo") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "foo"); + + // extension propagated + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.extensions().get::().unwrap().0, "2"); + } + #[tokio::test] async fn other_middleware_setting_request_id() { let svc = ServiceBuilder::new() diff --git a/tower-http/src/sensitive_headers.rs b/tower-http/src/sensitive_headers.rs index 3bd081db..a1917515 100644 --- a/tower-http/src/sensitive_headers.rs +++ b/tower-http/src/sensitive_headers.rs @@ -382,8 +382,9 @@ where mod tests { #[allow(unused_imports)] use super::*; + use crate::ServiceExt; use http::header; - use tower::{ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn multiple_value_header() { @@ -445,4 +446,62 @@ mod tests { assert!(value.is_sensitive()) } } + + #[tokio::test] + async fn multiple_value_header_service_ext() { + async fn response_set_cookie(req: http::Request<()>) -> Result, ()> { + let mut iter = req.headers().get_all(header::COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + + let mut resp = http::Response::new(()); + resp.headers_mut().append( + header::CONTENT_TYPE, + http::HeaderValue::from_static("text/html"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-1"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-2"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-3"), + ); + Ok(resp) + } + + let mut service = service_fn(response_set_cookie) + .set_sensitive_request_headers(vec![header::COOKIE]) + .set_sensitive_response_headers(vec![header::SET_COOKIE]); + + let mut req = http::Request::new(()); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+1")); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+2")); + + let resp = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(!resp + .headers() + .get(header::CONTENT_TYPE) + .unwrap() + .is_sensitive()); + + let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + } } diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs index 8a29ca3f..4c58c1a3 100644 --- a/tower-http/src/service_ext.rs +++ b/tower-http/src/service_ext.rs @@ -1,5 +1,3 @@ -// todo: tests - #[cfg(feature = "add-extension")] use crate::add_extension::AddExtension; #[cfg(feature = "auth")] diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 327266af..e7ccde18 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -416,9 +416,9 @@ where mod tests { #[allow(unused_imports)] use super::*; - use crate::test_helpers::Body; + use crate::{test_helpers::Body, ServiceExt}; use http::header; - use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower::{service_fn, BoxError, ServiceBuilder, ServiceExt as TowerServiceExt}; #[tokio::test] async fn valid_accept_header() { @@ -436,6 +436,20 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn valid_accept_header_service_ext() { + let mut service = service_fn(echo).validate_accept_header("application/json"); + + let request = Request::get("/") + .header(header::ACCEPT, "application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + #[tokio::test] async fn valid_accept_header_accept_all_json() { let mut service = ServiceBuilder::new() From 649b46468101b75d83412312819c9dbf0bae7cd5 Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Sun, 26 Nov 2023 11:51:42 +0100 Subject: [PATCH 4/6] Fix ServiceExt docs --- tower-http/src/service_ext.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs index 4c58c1a3..8b4be6d3 100644 --- a/tower-http/src/service_ext.rs +++ b/tower-http/src/service_ext.rs @@ -107,7 +107,7 @@ pub trait ServiceExt: tower_service::Service { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`]. + /// Panics if the token is not a valid [HeaderValue](http::{HeaderValue). /// /// See the [auth](crate::auth) for more details. #[cfg(feature = "auth")] @@ -276,7 +276,7 @@ pub trait ServiceExt: tower_service::Service { RequestDecompression::new(self) } - /// Creates a new middleware that retries requests with a [`Service`] to follow redirection responses. + /// Creates a new middleware that retries requests with a [`Service`](tower::Service) to follow redirection responses. /// /// See the [follow_redirect](crate::follow_redirect) for more details. #[cfg(feature = "follow-redirect")] @@ -287,8 +287,8 @@ pub trait ServiceExt: tower_service::Service { FollowRedirect::new(self) } - /// Creates a new middleware that retries requests with a [`Service`] to follow redirection responses - /// with the given redirection [`Policy`]. + /// Creates a new middleware that retries requests with a [`Service`](tower::Service) to follow redirection responses + /// with the given redirection [`Policy`](crate::follow_redirect::policy::Policy). /// /// See the [follow_redirect](crate::follow_redirect) for more details. #[cfg(feature = "follow-redirect")] @@ -370,7 +370,7 @@ pub trait ServiceExt: tower_service::Service { /// Creates a new middleware that propagate request ids from requests to responses. /// /// If the request contains a matching header that header will be applied to responses. If a - /// [`RequestId`] extension is also present it will be propagated as well. + /// [`RequestId`](crate::request_id::RequestId) extension is also present it will be propagated as well. /// /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] @@ -385,7 +385,7 @@ pub trait ServiceExt: tower_service::Service { /// using `x-request-id` as the header name. /// /// If the request contains a matching header that header will be applied to responses. If a - /// [`RequestId`] extension is also present it will be propagated as well. + /// [`RequestId`](crate::request_id::RequestId) extension is also present it will be propagated as well. /// /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] @@ -401,8 +401,8 @@ pub trait ServiceExt: tower_service::Service { /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a /// header with the same name, then the header will be inserted. /// - /// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other - /// services can access it. + /// Additionally [`RequestId`](crate::request_id::RequestId) will be inserted into + /// the Request extensions so other services can access it. /// /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] @@ -420,8 +420,8 @@ pub trait ServiceExt: tower_service::Service { /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a /// header with the same name, then the header will be inserted. /// - /// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other - /// services can access it. + /// Additionally [`RequestId`](crate::request_id::RequestId) will be inserted into + /// the Request extensions so other services can access it. /// /// See the [request_id](crate::request_id) for an example. #[cfg(feature = "request-id")] @@ -723,7 +723,7 @@ pub trait ServiceExt: tower_service::Service { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`]. + /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). /// /// See the [validate_request](crate::validate_request) for an example. #[cfg(feature = "validate-request")] From 9518f2efcd2dda344fc6e97adabb94183bad3287 Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Mon, 27 Nov 2023 13:36:04 +0100 Subject: [PATCH 5/6] Fix feature gates --- tower-http/src/service_ext.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs index 8b4be6d3..59e3638c 100644 --- a/tower-http/src/service_ext.rs +++ b/tower-http/src/service_ext.rs @@ -1,5 +1,7 @@ #[cfg(feature = "add-extension")] use crate::add_extension::AddExtension; +#[cfg(all(feature = "validate-request", feature = "auth"))] +use crate::auth::require_authorization::{Basic, Bearer}; #[cfg(feature = "auth")] use crate::auth::{AddAuthorization, AsyncRequireAuthorization}; #[cfg(feature = "catch-panic")] @@ -45,10 +47,7 @@ use crate::set_header::{SetRequestHeader, SetResponseHeader}; #[cfg(feature = "set-status")] use crate::set_status::SetStatus; #[cfg(feature = "validate-request")] -use crate::{ - auth::require_authorization::{Basic, Bearer}, - validate_request::{AcceptHeader, ValidateRequestHeader}, -}; +use crate::validate_request::{AcceptHeader, ValidateRequestHeader}; #[cfg(feature = "trace")] use crate::{ classify::{GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier}, @@ -212,7 +211,6 @@ pub trait ServiceExt: tower_service::Service { /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[cfg(feature = "cors")] - #[cfg(feature = "cors")] fn add_cors_permissive(self) -> Cors where Self: Sized, @@ -587,6 +585,7 @@ pub trait ServiceExt: tower_service::Service { /// Request Timeout` response will be sent. /// /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] fn timeout(self, timeout: Duration) -> Timeout where Self: Sized, @@ -597,6 +596,7 @@ pub trait ServiceExt: tower_service::Service { /// Creates a new middleware that applies a timeout to request bodies. /// /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] fn timeout_request_body(self, timeout: Duration) -> RequestBodyTimeout where Self: Sized, @@ -607,6 +607,7 @@ pub trait ServiceExt: tower_service::Service { /// Creates a new middleware that applies a timeout to response bodies. /// /// See the [timeout](crate::timeout) for an example. + #[cfg(feature = "timeout")] fn timeout_response_body(self, timeout: Duration) -> ResponseBodyTimeout where Self: Sized, @@ -703,7 +704,7 @@ pub trait ServiceExt: tower_service::Service { /// with this method. However use of HTTPS/TLS is not enforced by this middleware. /// /// See the [validate_request](crate::validate_request) for an example. - #[cfg(feature = "validate-request")] + #[cfg(all(feature = "validate-request", feature = "auth"))] fn validate_basic_authorization( self, username: &str, @@ -726,7 +727,7 @@ pub trait ServiceExt: tower_service::Service { /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). /// /// See the [validate_request](crate::validate_request) for an example. - #[cfg(feature = "validate-request")] + #[cfg(all(feature = "validate-request", feature = "auth"))] fn validate_bearer_authorization( self, token: &str, From b3b05665a6ab39c81c4980a5572cd370961a0226 Mon Sep 17 00:00:00 2001 From: Guillermo Lloret Talavera Date: Mon, 27 Nov 2023 20:21:43 +0100 Subject: [PATCH 6/6] Document changes in CHANGELOG --- tower-http/CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index ddb718c9..0cc5e860 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,7 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- Add `ServiceExt` trait ([#410]) + +[#410]: https://github.com/tower-rs/tower-http/pull/410 ## Changed