Skip to content

Commit 5ed4f52

Browse files
committed
Update example tls-graceful-shutdown to axum 0.7
1 parent 3c7cf81 commit 5ed4f52

File tree

2 files changed

+248
-137
lines changed

2 files changed

+248
-137
lines changed

examples/tls-graceful-shutdown/Cargo.toml

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@ publish = false
66

77
[dependencies]
88
axum = { path = "../../axum" }
9-
axum-server = { version = "0.3", features = ["tls-rustls"] }
10-
hyper = { version = "0.14", features = ["full"] }
9+
futures-util = { version = "0.3", default-features = false }
10+
hyper = { version = "1.0.0", features = ["full"] }
11+
hyper-util = { version = "0.1" }
12+
rustls-pemfile = "1.0.4"
13+
scopeguard = "1.2.0"
1114
tokio = { version = "1", features = ["full"] }
15+
tokio-rustls = "0.24.1"
16+
tower = { version = "0.4", features = [] }
17+
tower-service = "0.3.2"
1218
tracing = "0.1"
1319
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

examples/tls-graceful-shutdown/src/main.rs

+240-135
Original file line numberDiff line numberDiff line change
@@ -4,140 +4,245 @@
44
//! cargo run -p example-tls-graceful-shutdown
55
//! ```
66
7-
fn main() {
8-
// This example has not yet been updated to Hyper 1.0
7+
use axum::extract::Host;
8+
use axum::handler::HandlerWithoutStateExt;
9+
use axum::http::{StatusCode, Uri};
10+
use axum::response::Redirect;
11+
use axum::{extract::Request, routing::get, BoxError, Router};
12+
use futures_util::{pin_mut, FutureExt};
13+
use hyper::body::Incoming;
14+
use hyper_util::rt::{TokioExecutor, TokioIo};
15+
use rustls_pemfile::{certs, pkcs8_private_keys};
16+
use std::future::{Future, IntoFuture};
17+
use std::sync::atomic::{AtomicU64, Ordering};
18+
use std::time::Duration;
19+
use std::{
20+
fs::File,
21+
io::BufReader,
22+
path::{Path, PathBuf},
23+
sync::Arc,
24+
};
25+
use tokio::net::TcpListener;
26+
use tokio::{select, signal};
27+
use tokio_rustls::{
28+
rustls::{Certificate, PrivateKey, ServerConfig},
29+
TlsAcceptor,
30+
};
31+
use tower_service::Service;
32+
use tracing::{error, info, warn};
33+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
34+
35+
// !!!!!!!!! WARNING !!!!!!!!!!
36+
//
37+
// The code only gracefully shutdowns connections/tasks that are managed by axum.
38+
// If inside your handler you spawn a task (with tokio::spawn) and return a response. This task will not be gracefully shutdown.
39+
// So it is up to you to track those tasks and await for them correctly before terminating
40+
//
41+
// !!!!!!!!! WARNING !!!!!!!!!!
42+
#[tokio::main]
43+
async fn main() {
44+
tracing_subscriber::registry()
45+
.with(
46+
tracing_subscriber::EnvFilter::try_from_default_env()
47+
.unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()),
48+
)
49+
.with(tracing_subscriber::fmt::layer())
50+
.init();
51+
52+
let rustls_config = rustls_server_config(
53+
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
54+
.join("self_signed_certs")
55+
.join("key.pem"),
56+
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
57+
.join("self_signed_certs")
58+
.join("cert.pem"),
59+
);
60+
61+
let app = Router::new().route("/", get(handler));
62+
63+
let nb_inflight_requests = Arc::new(AtomicU64::new(0));
64+
let shutdown_signal = mk_shutdown_signal().fuse();
65+
let tls_acceptor = TlsAcceptor::from(rustls_config);
66+
67+
let ports = Ports {
68+
http: 3080,
69+
https: 3443,
70+
};
71+
let bind = format!("[::1]:{}", ports.https);
72+
let tcp_listener = TcpListener::bind(&bind).await.unwrap();
73+
info!(
74+
"HTTPS server listening on {bind}. To contact curl -k https://localhost:{}",
75+
ports.https
76+
);
77+
tokio::spawn(redirect_http_to_https(ports, mk_shutdown_signal()));
78+
79+
pin_mut!(shutdown_signal);
80+
loop {
81+
let tower_service = app.clone();
82+
let tls_acceptor = tls_acceptor.clone();
83+
84+
// Wait for new tcp connection or shutdown signal
85+
let (cnx, addr) = select! {
86+
biased;
87+
88+
_ = &mut shutdown_signal => {
89+
break;
90+
}
91+
92+
cnx = tcp_listener.accept() => {
93+
let Ok(cnx) = cnx else {
94+
error!("error accepting connection");
95+
break;
96+
};
97+
nb_inflight_requests.fetch_add(1, Ordering::Relaxed);
98+
cnx
99+
}
100+
};
101+
102+
let nb_inflight_requests = nb_inflight_requests.clone();
103+
tokio::spawn(async move {
104+
let _guard = scopeguard::guard((), |_| {
105+
nb_inflight_requests.fetch_sub(1, Ordering::Relaxed);
106+
});
107+
108+
// Wait for tls handshake to happen
109+
let Ok(stream) = tls_acceptor.accept(cnx).await else {
110+
error!("error during tls handshake connection from {}", addr);
111+
return;
112+
};
113+
114+
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
115+
// `TokioIo` converts between them.
116+
let stream = TokioIo::new(stream);
117+
118+
// Hyper has also its own `Service` trait and doesn't use tower. We can use
119+
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
120+
// `tower::Service::call`.
121+
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
122+
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
123+
// tower's `Service` requires `&mut self`.
124+
//
125+
// We don't need to call `poll_ready` since `Router` is always ready.
126+
tower_service.clone().call(request)
127+
});
128+
129+
let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
130+
.serve_connection_with_upgrades(stream, hyper_service)
131+
.await;
132+
133+
if let Err(err) = ret {
134+
warn!("error serving connection from {}: {}", addr, err);
135+
}
136+
});
137+
}
138+
139+
drop(tls_acceptor);
140+
drop(tcp_listener);
141+
info!("Server is shutting down. Waiting for inflight requests to complete before terminating");
142+
loop {
143+
let nb_inflights = nb_inflight_requests.load(Ordering::Relaxed);
144+
if nb_inflights == 0 {
145+
break;
146+
}
147+
info!("Server is shutting down. Waiting for {} inflight requests to complete before terminating", nb_inflights);
148+
tokio::time::sleep(Duration::from_secs(1)).await;
149+
}
9150
}
10151

11-
//use axum::{
12-
// extract::Host,
13-
// handler::HandlerWithoutStateExt,
14-
// http::{StatusCode, Uri},
15-
// response::Redirect,
16-
// routing::get,
17-
// BoxError, Router,
18-
//};
19-
//use axum_server::tls_rustls::RustlsConfig;
20-
//use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration};
21-
//use tokio::signal;
22-
//use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
23-
24-
//#[derive(Clone, Copy)]
25-
//struct Ports {
26-
// http: u16,
27-
// https: u16,
28-
//}
29-
30-
//#[tokio::main]
31-
//async fn main() {
32-
// tracing_subscriber::registry()
33-
// .with(
34-
// tracing_subscriber::EnvFilter::try_from_default_env()
35-
// .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()),
36-
// )
37-
// .with(tracing_subscriber::fmt::layer())
38-
// .init();
39-
40-
// let ports = Ports {
41-
// http: 7878,
42-
// https: 3000,
43-
// };
44-
45-
// //Create a handle for our TLS server so the shutdown signal can all shutdown
46-
// let handle = axum_server::Handle::new();
47-
// //save the future for easy shutting down of redirect server
48-
// let shutdown_future = shutdown_signal(handle.clone());
49-
50-
// // optional: spawn a second server to redirect http requests to this server
51-
// tokio::spawn(redirect_http_to_https(ports, shutdown_future));
52-
53-
// // configure certificate and private key used by https
54-
// let config = RustlsConfig::from_pem_file(
55-
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
56-
// .join("self_signed_certs")
57-
// .join("cert.pem"),
58-
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
59-
// .join("self_signed_certs")
60-
// .join("key.pem"),
61-
// )
62-
// .await
63-
// .unwrap();
64-
65-
// let app = Router::new().route("/", get(handler));
66-
67-
// // run https server
68-
// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
69-
// tracing::debug!("listening on {addr}");
70-
// axum_server::bind_rustls(addr, config)
71-
// .handle(handle)
72-
// .serve(app.into_make_service())
73-
// .await
74-
// .unwrap();
75-
//}
76-
77-
//async fn shutdown_signal(handle: axum_server::Handle) {
78-
// let ctrl_c = async {
79-
// signal::ctrl_c()
80-
// .await
81-
// .expect("failed to install Ctrl+C handler");
82-
// };
83-
84-
// #[cfg(unix)]
85-
// let terminate = async {
86-
// signal::unix::signal(signal::unix::SignalKind::terminate())
87-
// .expect("failed to install signal handler")
88-
// .recv()
89-
// .await;
90-
// };
91-
92-
// #[cfg(not(unix))]
93-
// let terminate = std::future::pending::<()>();
94-
95-
// tokio::select! {
96-
// _ = ctrl_c => {},
97-
// _ = terminate => {},
98-
// }
99-
100-
// tracing::info!("Received termination signal shutting down");
101-
// handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait
102-
// // to force shutdown
103-
//}
104-
105-
//async fn handler() -> &'static str {
106-
// "Hello, World!"
107-
//}
108-
109-
//async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) {
110-
// fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> {
111-
// let mut parts = uri.into_parts();
112-
113-
// parts.scheme = Some(axum::http::uri::Scheme::HTTPS);
114-
115-
// if parts.path_and_query.is_none() {
116-
// parts.path_and_query = Some("/".parse().unwrap());
117-
// }
118-
119-
// let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string());
120-
// parts.authority = Some(https_host.parse()?);
121-
122-
// Ok(Uri::from_parts(parts)?)
123-
// }
124-
125-
// let redirect = move |Host(host): Host, uri: Uri| async move {
126-
// match make_https(host, uri, ports) {
127-
// Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
128-
// Err(error) => {
129-
// tracing::warn!(%error, "failed to convert URI to HTTPS");
130-
// Err(StatusCode::BAD_REQUEST)
131-
// }
132-
// }
133-
// };
134-
135-
// let addr = SocketAddr::from(([127, 0, 0, 1], ports.http));
136-
// //let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
137-
// tracing::debug!("listening on {addr}");
138-
// hyper::Server::bind(&addr)
139-
// .serve(redirect.into_make_service())
140-
// .with_graceful_shutdown(signal)
141-
// .await
142-
// .unwrap();
143-
//}
152+
async fn handler() -> &'static str {
153+
tokio::time::sleep(Duration::from_secs(5)).await;
154+
"Hello, World!"
155+
}
156+
157+
fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
158+
let mut key_reader = BufReader::new(File::open(key).unwrap());
159+
let mut cert_reader = BufReader::new(File::open(cert).unwrap());
160+
161+
let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0));
162+
let certs = certs(&mut cert_reader)
163+
.unwrap()
164+
.into_iter()
165+
.map(Certificate)
166+
.collect();
167+
168+
let mut config = ServerConfig::builder()
169+
.with_safe_defaults()
170+
.with_no_client_auth()
171+
.with_single_cert(certs, key)
172+
.expect("bad certificate/key");
173+
174+
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
175+
176+
Arc::new(config)
177+
}
178+
179+
async fn mk_shutdown_signal() {
180+
#[cfg(unix)]
181+
let terminate = async {
182+
signal::unix::signal(signal::unix::SignalKind::terminate())
183+
.expect("failed to install signal handler")
184+
.recv()
185+
.await;
186+
};
187+
188+
#[cfg(not(unix))]
189+
let terminate = std::future::pending::<()>();
190+
191+
select! {
192+
_ = signal::ctrl_c() => {},
193+
_ = terminate => {},
194+
}
195+
196+
info!("Received termination signal shutting down");
197+
}
198+
199+
// Redirect HTTP to HTTPS
200+
#[derive(Clone, Copy)]
201+
struct Ports {
202+
http: u16,
203+
https: u16,
204+
}
205+
206+
async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) {
207+
fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> {
208+
let mut parts = uri.into_parts();
209+
210+
parts.scheme = Some(axum::http::uri::Scheme::HTTPS);
211+
212+
if parts.path_and_query.is_none() {
213+
parts.path_and_query = Some("/".parse().unwrap());
214+
}
215+
216+
let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string());
217+
parts.authority = Some(https_host.parse()?);
218+
219+
Ok(Uri::from_parts(parts)?)
220+
}
221+
222+
let redirect = move |Host(host): Host, uri: Uri| async move {
223+
match make_https(host, uri, ports) {
224+
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
225+
Err(error) => {
226+
warn!(%error, "failed to convert URI to HTTPS");
227+
Err(StatusCode::BAD_REQUEST)
228+
}
229+
}
230+
};
231+
232+
let bind = format!("[::1]:{}", ports.http);
233+
let listener = TcpListener::bind(&bind).await.unwrap();
234+
info!(
235+
"HTTP server listening on {bind}. To contact curl http://localhost:{}",
236+
ports.http
237+
);
238+
let server = axum::serve(listener, redirect.into_make_service()).into_future();
239+
240+
select! {
241+
biased;
242+
243+
_ = signal => {},
244+
_ = server => {},
245+
}
246+
247+
info!("HTTP server shutdown");
248+
}

0 commit comments

Comments
 (0)