Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ServiceExt trait #410

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Added

- None.
- Add `ServiceExt` trait ([#410])

[#410]: https://github.com/tower-rs/tower-http/pull/410

## Changed

Expand Down
23 changes: 21 additions & 2 deletions tower-http/src/add_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use crate::test_helpers::Body;
use crate::{test_helpers::Body, ServiceExt};
use http::Response;
use std::{convert::Infallible, sync::Arc};
use tower::{service_fn, ServiceBuilder, ServiceExt};
use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt};

struct State(i32);

Expand All @@ -164,4 +164,23 @@ mod tests {

assert_eq!(1, res);
}

#[tokio::test]
async fn basic_service_ext() {
let state = Arc::new(State(1));

let svc = service_fn(|req: Request<Body>| async move {
let state = req.extensions().get::<Arc<State>>().unwrap();
Ok::<_, Infallible>(Response::new(state.0))
})
.add_extension(state);

let res = svc
.oneshot(Request::new(Body::empty()))
.await
.unwrap()
.into_body();

assert_eq!(1, res);
}
}
42 changes: 40 additions & 2 deletions tower-http/src/auth/add_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::Body;
use crate::validate_request::ValidateRequestHeaderLayer;
use crate::{test_helpers::Body, ServiceExt};
use http::{Response, StatusCode};
use std::convert::Infallible;
use tower::{BoxError, Service, ServiceBuilder, ServiceExt};
use tower::{service_fn, BoxError, Service, ServiceBuilder, ServiceExt as TowerServiceExt};

#[tokio::test]
async fn basic() {
Expand All @@ -216,6 +216,25 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn basic_service_ext() {
// service that requires auth for all requests
let svc = service_fn(echo).require_basic_authorization("foo", "bar");

// make a client that adds auth
let mut client = AddAuthorization::basic(svc, "foo", "bar");

let res = client
.ready()
.await
.unwrap()
.call(Request::new(Body::empty()))
.await
.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn token() {
// service that requires auth for all requests
Expand All @@ -237,6 +256,25 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn token_service_ext() {
// service that requires auth for all requests
let svc = service_fn(echo).require_bearer_authorization("foo");

// make a client that adds auth
let mut client = AddAuthorization::bearer(svc, "foo");

let res = client
.ready()
.await
.unwrap()
.call(Request::new(Body::empty()))
.await
.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn making_header_sensitive() {
let svc = ServiceBuilder::new()
Expand Down
18 changes: 16 additions & 2 deletions tower-http/src/auth/async_require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,10 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use crate::test_helpers::Body;
use crate::{test_helpers::Body, ServiceExt};
use futures_util::future::BoxFuture;
use http::{header, StatusCode};
use tower::{BoxError, ServiceBuilder, ServiceExt};
use tower::{service_fn, BoxError, ServiceBuilder, ServiceExt as TowerServiceExt};

#[derive(Clone, Copy)]
struct MyAuth;
Expand Down Expand Up @@ -383,6 +383,20 @@ mod tests {
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn require_async_auth_401_service_ext() {
let mut service = service_fn(echo).async_require_authorization(MyAuth);

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer deez")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
Expand Down
21 changes: 19 additions & 2 deletions tower-http/src/catch_panic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,10 @@ mod tests {
#![allow(unreachable_code)]

use super::*;
use crate::test_helpers::Body;
use crate::{test_helpers::Body, ServiceExt};
use http::Response;
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};
use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt};

#[tokio::test]
async fn panic_before_returning_future() {
Expand Down Expand Up @@ -406,4 +406,21 @@ mod tests {
let body = crate::test_helpers::to_bytes(res).await.unwrap();
assert_eq!(&body[..], b"Service panicked");
}

#[tokio::test]
async fn panic_in_future_service_ext() {
let svc = service_fn(|_: Request<Body>| async {
panic!("future panic");
Ok::<_, Infallible>(Response::new(Body::empty()))
})
.catch_panic();

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = crate::test_helpers::to_bytes(res).await.unwrap();
assert_eq!(&body[..], b"Service panicked");
}
}
36 changes: 34 additions & 2 deletions tower-http/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ mod tests {
use crate::compression::predicate::SizeAbove;

use super::*;
use crate::test_helpers::{Body, WithTrailers};
use crate::{
test_helpers::{Body, WithTrailers},
ServiceExt,
};
use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
use flate2::read::GzDecoder;
use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
Expand All @@ -103,7 +106,7 @@ mod tests {
use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::io::StreamReader;
use tower::{service_fn, Service, ServiceExt};
use tower::{service_fn, Service, ServiceExt as TowerServiceExt};

// Compression filter allows every other request to be compressed
#[derive(Clone)]
Expand Down Expand Up @@ -148,6 +151,35 @@ mod tests {
assert_eq!(trailers["foo"], "bar");
}

#[tokio::test]
async fn gzip_works_service_ext() {
let mut svc = service_fn(handle).compress_when(Always);

// call the service
let req = Request::builder()
.header("accept-encoding", "gzip")
.body(Body::empty())
.unwrap();
let res = svc.ready().await.unwrap().call(req).await.unwrap();

// read the compressed body
let collected = res.into_body().collect().await.unwrap();
let trailers = collected.trailers().cloned().unwrap();
let compressed_data = collected.to_bytes();

// decompress the body
// doing this with flate2 as that is much easier than async-compression and blocking during
// tests is fine
let mut decoder = GzDecoder::new(&compressed_data[..]);
let mut decompressed = String::new();
decoder.read_to_string(&mut decompressed).unwrap();

assert_eq!(decompressed, "Hello, World!");

// trailers are maintained
assert_eq!(trailers["foo"], "bar");
}

#[tokio::test]
async fn zstd_works() {
let svc = service_fn(handle);
Expand Down
26 changes: 24 additions & 2 deletions tower-http/src/decompression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ mod tests {
use std::io::Write;

use super::*;
use crate::test_helpers::Body;
use crate::{compression::Compression, test_helpers::WithTrailers};
use crate::{test_helpers::Body, ServiceExt};
use flate2::write::GzEncoder;
use http::Response;
use http::{HeaderMap, HeaderName, Request};
use http_body_util::BodyExt;
use tower::{service_fn, Service, ServiceExt};
use tower::{service_fn, Service, ServiceExt as TowerServiceExt};

#[tokio::test]
async fn works() {
Expand All @@ -143,6 +143,28 @@ mod tests {
assert_eq!(trailers["foo"], "bar");
}

#[tokio::test]
async fn works_service_ext() {
let mut client = Compression::new(service_fn(handle)).decompress();

let req = Request::builder()
.header("accept-encoding", "gzip")
.body(Body::empty())
.unwrap();
let res = client.ready().await.unwrap().call(req).await.unwrap();

// read the body, it will be decompressed automatically
let body = res.into_body();
let collected = body.collect().await.unwrap();
let trailers = collected.trailers().cloned().unwrap();
let decompressed_data = String::from_utf8(collected.to_bytes().to_vec()).unwrap();

assert_eq!(decompressed_data, "Hello, World!");

// maintains trailers
assert_eq!(trailers["foo"], "bar");
}

async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
let mut trailers = HeaderMap::new();
trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
Expand Down
19 changes: 17 additions & 2 deletions tower-http/src/follow_redirect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
#[cfg(test)]
mod tests {
use super::{policy::*, *};
use crate::test_helpers::Body;
use crate::{test_helpers::Body, ServiceExt};
use http::header::LOCATION;
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};
use tower::{service_fn, ServiceBuilder, ServiceExt as TowerServiceExt};

#[tokio::test]
async fn follows() {
Expand All @@ -411,6 +411,21 @@ mod tests {
);
}

#[tokio::test]
async fn follows_service_ext() {
let svc = service_fn(handle).follow_redirect_with_policy(Action::Follow);
let req = Request::builder()
.uri("http://example.com/42")
.body(Body::empty())
.unwrap();
let res = svc.oneshot(req).await.unwrap();
assert_eq!(*res.body(), 0);
assert_eq!(
res.extensions().get::<RequestUri>().unwrap().0,
"http://example.com/0"
);
}

#[tokio::test]
async fn stops() {
let svc = ServiceBuilder::new()
Expand Down
3 changes: 3 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ pub mod validate_request;

pub mod body;

mod service_ext;
pub use self::service_ext::ServiceExt;

/// The latency unit used to report latencies by middleware.
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
Expand Down
30 changes: 28 additions & 2 deletions tower-http/src/metrics/in_flight_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use crate::test_helpers::Body;
use crate::{test_helpers::Body, ServiceExt};
use http::Request;
use tower::{BoxError, ServiceBuilder};
use tower::{service_fn, BoxError, ServiceBuilder};

#[tokio::test]
async fn basic() {
Expand Down Expand Up @@ -321,6 +321,32 @@ mod tests {
assert_eq!(counter.get(), 0);
}

#[tokio::test]
async fn basic_service_ext() {
let counter = InFlightRequestsCounter::new();

let mut service = service_fn(echo).count_in_flight_requests(counter.clone());
assert_eq!(counter.get(), 0);

// driving service to ready shouldn't increment the counter
std::future::poll_fn(|cx| service.poll_ready(cx))
.await
.unwrap();
assert_eq!(counter.get(), 0);

// creating the response future should increment the count
let response_future = service.call(Request::new(Body::empty()));
assert_eq!(counter.get(), 1);

// count shouldn't decrement until the full body has been comsumed
let response = response_future.await.unwrap();
assert_eq!(counter.get(), 1);

let body = response.into_body();
crate::test_helpers::to_bytes(body).await.unwrap();
assert_eq!(counter.get(), 0);
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
Expand Down
Loading