|
4 | 4 | //! cargo run -p example-tls-graceful-shutdown
|
5 | 5 | //! ```
|
6 | 6 |
|
7 |
| -fn main() { |
8 |
| - // This example has not yet been updated to Hyper 1.0 |
| 7 | +use axum::extract::Host; |
| 8 | +use axum::handler::HandlerWithoutStateExt; |
| 9 | +use axum::http::{StatusCode, Uri}; |
| 10 | +use axum::response::Redirect; |
| 11 | +use axum::{extract::Request, routing::get, BoxError, Router}; |
| 12 | +use futures_util::{pin_mut, FutureExt}; |
| 13 | +use hyper::body::Incoming; |
| 14 | +use hyper_util::rt::{TokioExecutor, TokioIo}; |
| 15 | +use rustls_pemfile::{certs, pkcs8_private_keys}; |
| 16 | +use std::future::{Future, IntoFuture}; |
| 17 | +use std::sync::atomic::{AtomicU64, Ordering}; |
| 18 | +use std::time::Duration; |
| 19 | +use std::{ |
| 20 | + fs::File, |
| 21 | + io::BufReader, |
| 22 | + path::{Path, PathBuf}, |
| 23 | + sync::Arc, |
| 24 | +}; |
| 25 | +use tokio::net::TcpListener; |
| 26 | +use tokio::{select, signal}; |
| 27 | +use tokio_rustls::{ |
| 28 | + rustls::{Certificate, PrivateKey, ServerConfig}, |
| 29 | + TlsAcceptor, |
| 30 | +}; |
| 31 | +use tower_service::Service; |
| 32 | +use tracing::{error, info, warn}; |
| 33 | +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; |
| 34 | + |
| 35 | +// !!!!!!!!! WARNING !!!!!!!!!! |
| 36 | +// |
| 37 | +// The code only gracefully shutdowns connections/tasks that are managed by axum. |
| 38 | +// If inside your handler you spawn a task (with tokio::spawn) and return a response. This task will not be gracefully shutdown. |
| 39 | +// So it is up to you to track those tasks and await for them correctly before terminating |
| 40 | +// |
| 41 | +// !!!!!!!!! WARNING !!!!!!!!!! |
| 42 | +#[tokio::main] |
| 43 | +async fn main() { |
| 44 | + tracing_subscriber::registry() |
| 45 | + .with( |
| 46 | + tracing_subscriber::EnvFilter::try_from_default_env() |
| 47 | + .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), |
| 48 | + ) |
| 49 | + .with(tracing_subscriber::fmt::layer()) |
| 50 | + .init(); |
| 51 | + |
| 52 | + let rustls_config = rustls_server_config( |
| 53 | + PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
| 54 | + .join("self_signed_certs") |
| 55 | + .join("key.pem"), |
| 56 | + PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
| 57 | + .join("self_signed_certs") |
| 58 | + .join("cert.pem"), |
| 59 | + ); |
| 60 | + |
| 61 | + let app = Router::new().route("/", get(handler)); |
| 62 | + |
| 63 | + let nb_inflight_requests = Arc::new(AtomicU64::new(0)); |
| 64 | + let shutdown_signal = mk_shutdown_signal().fuse(); |
| 65 | + let tls_acceptor = TlsAcceptor::from(rustls_config); |
| 66 | + |
| 67 | + let ports = Ports { |
| 68 | + http: 3080, |
| 69 | + https: 3443, |
| 70 | + }; |
| 71 | + let bind = format!("[::1]:{}", ports.https); |
| 72 | + let tcp_listener = TcpListener::bind(&bind).await.unwrap(); |
| 73 | + info!( |
| 74 | + "HTTPS server listening on {bind}. To contact curl -k https://localhost:{}", |
| 75 | + ports.https |
| 76 | + ); |
| 77 | + tokio::spawn(redirect_http_to_https(ports, mk_shutdown_signal())); |
| 78 | + |
| 79 | + pin_mut!(shutdown_signal); |
| 80 | + loop { |
| 81 | + let tower_service = app.clone(); |
| 82 | + let tls_acceptor = tls_acceptor.clone(); |
| 83 | + |
| 84 | + // Wait for new tcp connection or shutdown signal |
| 85 | + let (cnx, addr) = select! { |
| 86 | + biased; |
| 87 | + |
| 88 | + _ = &mut shutdown_signal => { |
| 89 | + break; |
| 90 | + } |
| 91 | + |
| 92 | + cnx = tcp_listener.accept() => { |
| 93 | + let Ok(cnx) = cnx else { |
| 94 | + error!("error accepting connection"); |
| 95 | + break; |
| 96 | + }; |
| 97 | + nb_inflight_requests.fetch_add(1, Ordering::Relaxed); |
| 98 | + cnx |
| 99 | + } |
| 100 | + }; |
| 101 | + |
| 102 | + let nb_inflight_requests = nb_inflight_requests.clone(); |
| 103 | + tokio::spawn(async move { |
| 104 | + let _guard = scopeguard::guard((), |_| { |
| 105 | + nb_inflight_requests.fetch_sub(1, Ordering::Relaxed); |
| 106 | + }); |
| 107 | + |
| 108 | + // Wait for tls handshake to happen |
| 109 | + let Ok(stream) = tls_acceptor.accept(cnx).await else { |
| 110 | + error!("error during tls handshake connection from {}", addr); |
| 111 | + return; |
| 112 | + }; |
| 113 | + |
| 114 | + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. |
| 115 | + // `TokioIo` converts between them. |
| 116 | + let stream = TokioIo::new(stream); |
| 117 | + |
| 118 | + // Hyper has also its own `Service` trait and doesn't use tower. We can use |
| 119 | + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through |
| 120 | + // `tower::Service::call`. |
| 121 | + let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| { |
| 122 | + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas |
| 123 | + // tower's `Service` requires `&mut self`. |
| 124 | + // |
| 125 | + // We don't need to call `poll_ready` since `Router` is always ready. |
| 126 | + tower_service.clone().call(request) |
| 127 | + }); |
| 128 | + |
| 129 | + let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) |
| 130 | + .serve_connection_with_upgrades(stream, hyper_service) |
| 131 | + .await; |
| 132 | + |
| 133 | + if let Err(err) = ret { |
| 134 | + warn!("error serving connection from {}: {}", addr, err); |
| 135 | + } |
| 136 | + }); |
| 137 | + } |
| 138 | + |
| 139 | + drop(tls_acceptor); |
| 140 | + drop(tcp_listener); |
| 141 | + info!("Server is shutting down. Waiting for inflight requests to complete before terminating"); |
| 142 | + loop { |
| 143 | + let nb_inflights = nb_inflight_requests.load(Ordering::Relaxed); |
| 144 | + if nb_inflights == 0 { |
| 145 | + break; |
| 146 | + } |
| 147 | + info!("Server is shutting down. Waiting for {} inflight requests to complete before terminating", nb_inflights); |
| 148 | + tokio::time::sleep(Duration::from_secs(1)).await; |
| 149 | + } |
9 | 150 | }
|
10 | 151 |
|
11 |
| -//use axum::{ |
12 |
| -// extract::Host, |
13 |
| -// handler::HandlerWithoutStateExt, |
14 |
| -// http::{StatusCode, Uri}, |
15 |
| -// response::Redirect, |
16 |
| -// routing::get, |
17 |
| -// BoxError, Router, |
18 |
| -//}; |
19 |
| -//use axum_server::tls_rustls::RustlsConfig; |
20 |
| -//use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration}; |
21 |
| -//use tokio::signal; |
22 |
| -//use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; |
23 |
| - |
24 |
| -//#[derive(Clone, Copy)] |
25 |
| -//struct Ports { |
26 |
| -// http: u16, |
27 |
| -// https: u16, |
28 |
| -//} |
29 |
| - |
30 |
| -//#[tokio::main] |
31 |
| -//async fn main() { |
32 |
| -// tracing_subscriber::registry() |
33 |
| -// .with( |
34 |
| -// tracing_subscriber::EnvFilter::try_from_default_env() |
35 |
| -// .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), |
36 |
| -// ) |
37 |
| -// .with(tracing_subscriber::fmt::layer()) |
38 |
| -// .init(); |
39 |
| - |
40 |
| -// let ports = Ports { |
41 |
| -// http: 7878, |
42 |
| -// https: 3000, |
43 |
| -// }; |
44 |
| - |
45 |
| -// //Create a handle for our TLS server so the shutdown signal can all shutdown |
46 |
| -// let handle = axum_server::Handle::new(); |
47 |
| -// //save the future for easy shutting down of redirect server |
48 |
| -// let shutdown_future = shutdown_signal(handle.clone()); |
49 |
| - |
50 |
| -// // optional: spawn a second server to redirect http requests to this server |
51 |
| -// tokio::spawn(redirect_http_to_https(ports, shutdown_future)); |
52 |
| - |
53 |
| -// // configure certificate and private key used by https |
54 |
| -// let config = RustlsConfig::from_pem_file( |
55 |
| -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
56 |
| -// .join("self_signed_certs") |
57 |
| -// .join("cert.pem"), |
58 |
| -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
59 |
| -// .join("self_signed_certs") |
60 |
| -// .join("key.pem"), |
61 |
| -// ) |
62 |
| -// .await |
63 |
| -// .unwrap(); |
64 |
| - |
65 |
| -// let app = Router::new().route("/", get(handler)); |
66 |
| - |
67 |
| -// // run https server |
68 |
| -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https)); |
69 |
| -// tracing::debug!("listening on {addr}"); |
70 |
| -// axum_server::bind_rustls(addr, config) |
71 |
| -// .handle(handle) |
72 |
| -// .serve(app.into_make_service()) |
73 |
| -// .await |
74 |
| -// .unwrap(); |
75 |
| -//} |
76 |
| - |
77 |
| -//async fn shutdown_signal(handle: axum_server::Handle) { |
78 |
| -// let ctrl_c = async { |
79 |
| -// signal::ctrl_c() |
80 |
| -// .await |
81 |
| -// .expect("failed to install Ctrl+C handler"); |
82 |
| -// }; |
83 |
| - |
84 |
| -// #[cfg(unix)] |
85 |
| -// let terminate = async { |
86 |
| -// signal::unix::signal(signal::unix::SignalKind::terminate()) |
87 |
| -// .expect("failed to install signal handler") |
88 |
| -// .recv() |
89 |
| -// .await; |
90 |
| -// }; |
91 |
| - |
92 |
| -// #[cfg(not(unix))] |
93 |
| -// let terminate = std::future::pending::<()>(); |
94 |
| - |
95 |
| -// tokio::select! { |
96 |
| -// _ = ctrl_c => {}, |
97 |
| -// _ = terminate => {}, |
98 |
| -// } |
99 |
| - |
100 |
| -// tracing::info!("Received termination signal shutting down"); |
101 |
| -// handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait |
102 |
| -// // to force shutdown |
103 |
| -//} |
104 |
| - |
105 |
| -//async fn handler() -> &'static str { |
106 |
| -// "Hello, World!" |
107 |
| -//} |
108 |
| - |
109 |
| -//async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) { |
110 |
| -// fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> { |
111 |
| -// let mut parts = uri.into_parts(); |
112 |
| - |
113 |
| -// parts.scheme = Some(axum::http::uri::Scheme::HTTPS); |
114 |
| - |
115 |
| -// if parts.path_and_query.is_none() { |
116 |
| -// parts.path_and_query = Some("/".parse().unwrap()); |
117 |
| -// } |
118 |
| - |
119 |
| -// let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); |
120 |
| -// parts.authority = Some(https_host.parse()?); |
121 |
| - |
122 |
| -// Ok(Uri::from_parts(parts)?) |
123 |
| -// } |
124 |
| - |
125 |
| -// let redirect = move |Host(host): Host, uri: Uri| async move { |
126 |
| -// match make_https(host, uri, ports) { |
127 |
| -// Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), |
128 |
| -// Err(error) => { |
129 |
| -// tracing::warn!(%error, "failed to convert URI to HTTPS"); |
130 |
| -// Err(StatusCode::BAD_REQUEST) |
131 |
| -// } |
132 |
| -// } |
133 |
| -// }; |
134 |
| - |
135 |
| -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.http)); |
136 |
| -// //let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); |
137 |
| -// tracing::debug!("listening on {addr}"); |
138 |
| -// hyper::Server::bind(&addr) |
139 |
| -// .serve(redirect.into_make_service()) |
140 |
| -// .with_graceful_shutdown(signal) |
141 |
| -// .await |
142 |
| -// .unwrap(); |
143 |
| -//} |
| 152 | +async fn handler() -> &'static str { |
| 153 | + tokio::time::sleep(Duration::from_secs(5)).await; |
| 154 | + "Hello, World!" |
| 155 | +} |
| 156 | + |
| 157 | +fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> { |
| 158 | + let mut key_reader = BufReader::new(File::open(key).unwrap()); |
| 159 | + let mut cert_reader = BufReader::new(File::open(cert).unwrap()); |
| 160 | + |
| 161 | + let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0)); |
| 162 | + let certs = certs(&mut cert_reader) |
| 163 | + .unwrap() |
| 164 | + .into_iter() |
| 165 | + .map(Certificate) |
| 166 | + .collect(); |
| 167 | + |
| 168 | + let mut config = ServerConfig::builder() |
| 169 | + .with_safe_defaults() |
| 170 | + .with_no_client_auth() |
| 171 | + .with_single_cert(certs, key) |
| 172 | + .expect("bad certificate/key"); |
| 173 | + |
| 174 | + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; |
| 175 | + |
| 176 | + Arc::new(config) |
| 177 | +} |
| 178 | + |
| 179 | +async fn mk_shutdown_signal() { |
| 180 | + #[cfg(unix)] |
| 181 | + let terminate = async { |
| 182 | + signal::unix::signal(signal::unix::SignalKind::terminate()) |
| 183 | + .expect("failed to install signal handler") |
| 184 | + .recv() |
| 185 | + .await; |
| 186 | + }; |
| 187 | + |
| 188 | + #[cfg(not(unix))] |
| 189 | + let terminate = std::future::pending::<()>(); |
| 190 | + |
| 191 | + select! { |
| 192 | + _ = signal::ctrl_c() => {}, |
| 193 | + _ = terminate => {}, |
| 194 | + } |
| 195 | + |
| 196 | + info!("Received termination signal shutting down"); |
| 197 | +} |
| 198 | + |
| 199 | +// Redirect HTTP to HTTPS |
| 200 | +#[derive(Clone, Copy)] |
| 201 | +struct Ports { |
| 202 | + http: u16, |
| 203 | + https: u16, |
| 204 | +} |
| 205 | + |
| 206 | +async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) { |
| 207 | + fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> { |
| 208 | + let mut parts = uri.into_parts(); |
| 209 | + |
| 210 | + parts.scheme = Some(axum::http::uri::Scheme::HTTPS); |
| 211 | + |
| 212 | + if parts.path_and_query.is_none() { |
| 213 | + parts.path_and_query = Some("/".parse().unwrap()); |
| 214 | + } |
| 215 | + |
| 216 | + let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); |
| 217 | + parts.authority = Some(https_host.parse()?); |
| 218 | + |
| 219 | + Ok(Uri::from_parts(parts)?) |
| 220 | + } |
| 221 | + |
| 222 | + let redirect = move |Host(host): Host, uri: Uri| async move { |
| 223 | + match make_https(host, uri, ports) { |
| 224 | + Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), |
| 225 | + Err(error) => { |
| 226 | + warn!(%error, "failed to convert URI to HTTPS"); |
| 227 | + Err(StatusCode::BAD_REQUEST) |
| 228 | + } |
| 229 | + } |
| 230 | + }; |
| 231 | + |
| 232 | + let bind = format!("[::1]:{}", ports.http); |
| 233 | + let listener = TcpListener::bind(&bind).await.unwrap(); |
| 234 | + info!( |
| 235 | + "HTTP server listening on {bind}. To contact curl http://localhost:{}", |
| 236 | + ports.http |
| 237 | + ); |
| 238 | + let server = axum::serve(listener, redirect.into_make_service()).into_future(); |
| 239 | + |
| 240 | + select! { |
| 241 | + biased; |
| 242 | + |
| 243 | + _ = signal => {}, |
| 244 | + _ = server => {}, |
| 245 | + } |
| 246 | + |
| 247 | + info!("HTTP server shutdown"); |
| 248 | +} |
0 commit comments