From 6c0780443ce7d29699df5b76b7d8d2acbcfd64ad Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Wed, 11 Sep 2024 10:50:47 +0000 Subject: [PATCH 1/5] chore: split TLS detection logic Signed-off-by: Zahari Dichev --- linkerd/app/admin/src/stack.rs | 8 +- linkerd/app/inbound/src/detect.rs | 22 +- linkerd/app/inbound/src/direct.rs | 22 +- linkerd/app/inbound/src/metrics/error.rs | 2 +- linkerd/app/src/tap.rs | 12 +- linkerd/meshtls/tests/util.rs | 57 +++-- linkerd/proxy/tap/src/accept.rs | 2 +- linkerd/tls/src/detect_sni.rs | 211 ++++++++++++++++++ .../{server => detect_sni}/client_hello.rs | 0 .../curl-example-com-client-hello.bin | Bin .../testdata/example-com-client-hello.bin | Bin linkerd/tls/src/lib.rs | 2 + linkerd/tls/src/server.rs | 150 +------------ 13 files changed, 269 insertions(+), 219 deletions(-) create mode 100644 linkerd/tls/src/detect_sni.rs rename linkerd/tls/src/{server => detect_sni}/client_hello.rs (100%) rename linkerd/tls/src/{server => detect_sni}/testdata/curl-example-com-client-hello.bin (100%) rename linkerd/tls/src/{server => detect_sni}/testdata/example-com-client-hello.bin (100%) diff --git a/linkerd/app/admin/src/stack.rs b/linkerd/app/admin/src/stack.rs index 7f36f5d392..49e4991ced 100644 --- a/linkerd/app/admin/src/stack.rs +++ b/linkerd/app/admin/src/stack.rs @@ -193,6 +193,7 @@ impl Config { .push(tls::NewDetectTls::::layer(TlsParams { identity, })) + .push(tls::NewDetectSNI::layer(DETECT_TIMEOUT.into())) .arc_new_tcp() .into_inner(); @@ -279,13 +280,6 @@ impl Param for Permitted { // === TlsParams === -impl ExtractParam for TlsParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - tls::server::Timeout(DETECT_TIMEOUT) - } -} - impl ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { diff --git a/linkerd/app/inbound/src/detect.rs b/linkerd/app/inbound/src/detect.rs index e77b85b1cc..a492c7cb41 100644 --- a/linkerd/app/inbound/src/detect.rs +++ 
b/linkerd/app/inbound/src/detect.rs @@ -52,12 +52,9 @@ struct Detect { struct ConfigureHttpDetect; #[derive(Clone)] -struct TlsParams { - timeout: tls::server::Timeout, - identity: identity::Server, -} +struct TlsParams(identity::Server); -type TlsIo = tls::server::Io>, I>; +type TlsIo = tls::server::Io>, I>; // === impl Inbound === @@ -244,11 +241,9 @@ impl Inbound>> { ) .arc_new_tcp() .push(tls::NewDetectTls::::layer( - TlsParams { - timeout: tls::server::Timeout(detect_timeout), - identity: rt.identity.server(), - }, + TlsParams(rt.identity.server()), )) + .push(tls::NewDetectSNI::layer(detect_timeout.into())) .arc_new_tcp() .push_switch( // Check the policy for this port and check whether @@ -420,17 +415,10 @@ impl svc::Param for Http { // === TlsParams === -impl svc::ExtractParam for TlsParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - self.timeout - } -} - impl svc::ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { - self.identity.clone() + self.0.clone() } } diff --git a/linkerd/app/inbound/src/direct.rs b/linkerd/app/inbound/src/direct.rs index 5025556ffd..ba8edb3ea2 100644 --- a/linkerd/app/inbound/src/direct.rs +++ b/linkerd/app/inbound/src/direct.rs @@ -70,15 +70,12 @@ pub struct ClientInfo { pub local_addr: OrigDstAddr, } -type TlsIo = tls::server::Io>, I>; +type TlsIo = tls::server::Io>, I>; type FwdIo = SensorIo>>; pub type GatewayIo = FwdIo; #[derive(Clone)] -struct TlsParams { - timeout: tls::server::Timeout, - identity: identity::Server, -} +struct TlsParams(identity::Server); impl Inbound { /// Builds a stack that handles connections that target the proxy's inbound port @@ -220,11 +217,9 @@ impl Inbound { }) .push(svc::ArcNewService::layer()) .push(tls::NewDetectTls::::layer( - TlsParams { - timeout: tls::server::Timeout(detect_timeout), - identity, - }, + TlsParams(identity), )) + .push(tls::NewDetectSNI::layer(detect_timeout.into())) .arc_new_tcp() }) } @@ -453,17 +448,10 
@@ impl From for Error { // === TlsParams === -impl ExtractParam for TlsParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - self.timeout - } -} - impl ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { - self.identity.clone() + self.0.clone() } } diff --git a/linkerd/app/inbound/src/metrics/error.rs b/linkerd/app/inbound/src/metrics/error.rs index b83313117a..63db374b9c 100644 --- a/linkerd/app/inbound/src/metrics/error.rs +++ b/linkerd/app/inbound/src/metrics/error.rs @@ -42,7 +42,7 @@ impl ErrorKind { Some(ErrorKind::FailFast) } else if err.is::() { Some(ErrorKind::Io) - } else if err.is::() { + } else if err.is::() { Some(ErrorKind::TlsDetectTimeout) } else if err.is::() { Some(ErrorKind::GatewayDomainInvalid) diff --git a/linkerd/app/src/tap.rs b/linkerd/app/src/tap.rs index c9f5d6cee5..3d07a422a1 100644 --- a/linkerd/app/src/tap.rs +++ b/linkerd/app/src/tap.rs @@ -9,7 +9,7 @@ use linkerd_app_core::{ transport::{addrs::AddrPair, listen::Bind, ClientAddr, Local, Remote, ServerAddr}, Error, }; -use std::{collections::HashSet, pin::Pin}; +use std::{collections::HashSet, pin::Pin, time::Duration}; use tower::util::{service_fn, ServiceExt}; #[derive(Clone, Debug)] @@ -38,6 +38,8 @@ struct TlsParams { identity: identity::Server, } +const DETECT_TIMEOUT: Duration = Duration::from_secs(1); + impl Config { pub fn build( self, @@ -83,6 +85,7 @@ impl Config { .push(tls::NewDetectTls::::layer( TlsParams { identity }, )) + .push(tls::NewDetectSNI::layer(DETECT_TIMEOUT.into())) .check_new_service::() .into_inner(); @@ -109,13 +112,6 @@ impl Tap { // === TlsParams === -impl ExtractParam for TlsParams { - #[inline] - fn extract_param(&self, _: &T) -> tls::server::Timeout { - tls::server::Timeout(std::time::Duration::from_secs(1)) - } -} - impl ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { diff --git a/linkerd/meshtls/tests/util.rs b/linkerd/meshtls/tests/util.rs 
index c294ecfecc..5859e43585 100644 --- a/linkerd/meshtls/tests/util.rs +++ b/linkerd/meshtls/tests/util.rs @@ -155,7 +155,7 @@ pub async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match(mode: type ServerConn = ( (tls::ConditionalServerTls, T), - io::EitherIo>, tls::server::DetectIo>, + io::EitherIo>, tls::detect_sni::DetectIo>, ); fn load( @@ -228,31 +228,34 @@ where // Saves the result of every connection. let (sender, receiver) = mpsc::channel::>(); - let detect = tls::NewDetectTls::::new( - ServerParams { - identity: server_tls, - }, - move |meta: (tls::ConditionalServerTls, Addrs)| { - let server = server.clone(); - let sender = sender.clone(); - let tls = meta.0.clone().map(Into::into); - service_fn(move |conn| { + let detect = tls::NewDetectSNI::new( + tls::NewDetectTls::::new( + ServerParams { + identity: server_tls, + }, + move |meta: (tls::ConditionalServerTls, Addrs)| { let server = server.clone(); let sender = sender.clone(); - let tls = Some(tls.clone()); - let future = server((meta.clone(), conn)); - Box::pin( - async move { - let result = future.await; - sender - .send(Transported { tls, result }) - .expect("send result"); - Ok::<(), Infallible>(()) - } - .instrument(tracing::info_span!("test_svc")), - ) - }) - }, + let tls = meta.0.clone().map(Into::into); + service_fn(move |conn| { + let server = server.clone(); + let sender = sender.clone(); + let tls = Some(tls.clone()); + let future = server((meta.clone(), conn)); + Box::pin( + async move { + let result = future.await; + sender + .send(Transported { tls, result }) + .expect("send result"); + Ok::<(), Infallible>(()) + } + .instrument(tracing::info_span!("test_svc")), + ) + }) + }, + ), + Duration::from_secs(10).into(), ); let (listen_addr, listen) = BindTcp::default().bind(&Server).expect("must bind"); @@ -409,12 +412,6 @@ impl Param for Server { // === impl ServerParams === -impl ExtractParam for ServerParams { - fn extract_param(&self, _: &T) -> tls::server::Timeout { - 
tls::server::Timeout(Duration::from_secs(10)) - } -} - impl ExtractParam for ServerParams { fn extract_param(&self, _: &T) -> meshtls::Server { self.identity.clone() diff --git a/linkerd/proxy/tap/src/accept.rs b/linkerd/proxy/tap/src/accept.rs index 57f9411958..de4ee7f76b 100644 --- a/linkerd/proxy/tap/src/accept.rs +++ b/linkerd/proxy/tap/src/accept.rs @@ -24,7 +24,7 @@ pub struct AcceptPermittedClients { type Connection = ( (tls::ConditionalServerTls, T), - io::EitherIo>, tls::server::DetectIo>, + io::EitherIo>, tls::detect_sni::DetectIo>, ); pub type ServeFuture = Pin> + Send + 'static>>; diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs new file mode 100644 index 0000000000..66f034742b --- /dev/null +++ b/linkerd/tls/src/detect_sni.rs @@ -0,0 +1,211 @@ +mod client_hello; + +use crate::ServerName; +use bytes::BytesMut; +use futures::prelude::*; +use linkerd_error::Error; +use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; +use linkerd_stack::{layer, NewService, Service, ServiceExt}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tokio::time::{self, Duration}; +use tracing::{debug, trace, warn}; + +pub type DetectIo = EitherIo>; + +#[derive(Clone, Debug)] +pub struct NewDetectSNI { + inner: N, + timeout: Timeout, +} + +#[derive(Copy, Clone, Debug)] +pub struct Timeout(pub Duration); + +#[derive(Clone, Debug, Error)] +#[error("SNI detection timed out")] +pub struct DetectSniTimeoutError(()); + +/// Attempts to detect an SNI from the client hello of a TLS session +#[derive(Clone, Debug)] +pub struct DetectSNI { + target: T, + inner: N, + timeout: Timeout, +} + +// The initial peek buffer is fairly small so that we can avoid allocating more +// data then we need; but it is large enough to hold the ~300B ClientHello sent +// by proxies. +const PEEK_CAPACITY: usize = 512; + +// A larger fallback buffer is allocated onto the heap if the initial peek +// buffer is insufficient. 
This is the same value used in HTTP detection. +const BUFFER_CAPACITY: usize = 8192; + +impl NewDetectSNI { + pub fn new(inner: N, timeout: Timeout) -> Self { + Self { inner, timeout } + } + + pub fn layer(timeout: Timeout) -> impl layer::Layer + Clone { + layer::mk(move |inner| Self::new(inner, timeout)) + } +} + +impl NewService for NewDetectSNI +where + N: Clone, +{ + type Service = DetectSNI; + + fn new_service(&self, target: T) -> Self::Service { + DetectSNI { + target, + inner: self.inner.clone(), + timeout: self.timeout, + } + } +} + +impl Service for DetectSNI +where + T: Clone + Send + 'static, + I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, + N: NewService<(T, Option), Service = S> + Clone + Send + 'static, + S: Service> + Send, + S::Error: Into, + S::Future: Send, +{ + type Response = S::Response; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { + let target = self.target.clone(); + let new_accept = self.inner.clone(); + + // Detect the SNI from a ClientHello (or timeout). + let Timeout(timeout) = self.timeout; + let detect = time::timeout(timeout, detect_sni(io)); + Box::pin(async move { + let (sni, io) = detect.await.map_err(|_| DetectSniTimeoutError(()))??; + + println!("detected SNI: {:?}", sni); + let svc = new_accept.new_service((target, sni)); + svc.oneshot(io).await.map_err(Into::into) + }) + } +} + +/// Peek or buffer the provided stream to determine an SNI value. +async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +where + I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, +{ + // First, try to use MSG_PEEK to read the SNI from the TLS ClientHello. We + // use a heap-allocated buffer to avoid creating a large `Future` (since we + // need to hold the buffer across an await). 
+ // + // Anecdotally, the ClientHello sent by Linkerd proxies is <300B. So a ~500B + // byte buffer is more than enough. + let mut buf = BytesMut::with_capacity(PEEK_CAPACITY); + let sz = io.peek(&mut buf).await?; + debug!(sz, "Peeked bytes from TCP stream"); + // Peek may return 0 bytes if the socket is not peekable. + if sz > 0 { + match client_hello::parse_sni(buf.as_ref()) { + Ok(sni) => { + return Ok((sni, EitherIo::Left(io))); + } + + Err(client_hello::Incomplete) => {} + } + } + + // Peeking didn't return enough data, so instead we'll allocate more + // capacity and try reading data from the socket. + debug!("Attempting to buffer TLS ClientHello after incomplete peek"); + let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY); + debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream"); + while io.read_buf(&mut buf).await? != 0 { + debug!(buf.len = %buf.len(), "Read bytes from TCP stream"); + match client_hello::parse_sni(buf.as_ref()) { + Ok(sni) => { + return Ok((sni, EitherIo::Right(PrefixedIo::new(buf.freeze(), io)))); + } + + Err(client_hello::Incomplete) => { + if buf.capacity() == 0 { + // If we can't buffer an entire TLS ClientHello, it + // almost definitely wasn't initiated by another proxy, + // at least. + warn!("Buffer insufficient for TLS ClientHello"); + break; + } + // Continue if there is still buffer capacity. 
+ } + } + } + + trace!("Could not read TLS ClientHello via buffering"); + let io = EitherIo::Right(PrefixedIo::new(buf.freeze(), io)); + Ok((None, io)) +} + +impl From for Timeout { + fn from(d: Duration) -> Self { + Self(d) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use linkerd_io::AsyncWriteExt; + + #[tokio::test(flavor = "current_thread")] + async fn detect_buffered() { + let _trace = linkerd_tracing::test::trace_init(); + + let (mut client_io, server_io) = linkerd_io::duplex(1024); + let input = include_bytes!("detect_sni/testdata/curl-example-com-client-hello.bin"); + let len = input.len(); + let client_task = tokio::spawn(async move { + client_io + .write_all(input) + .await + .expect("Write must succeed"); + }); + + let (sni, io) = detect_sni(server_io) + .await + .expect("SNI detection must not fail"); + + assert_eq!(sni, Some(ServerName("example.com".parse().unwrap()))); + + match io { + EitherIo::Left(_) => panic!("Detected IO should be buffered"), + EitherIo::Right(io) => assert_eq!(io.prefix().len(), len, "All data must be buffered"), + } + + client_task.await.expect("Client must not fail"); + } +} + +#[cfg(fuzzing)] +pub mod fuzz_logic { + use super::*; + + pub fn fuzz_entry(input: &[u8]) { + let _ = client_hello::parse_sni(input); + } +} diff --git a/linkerd/tls/src/server/client_hello.rs b/linkerd/tls/src/detect_sni/client_hello.rs similarity index 100% rename from linkerd/tls/src/server/client_hello.rs rename to linkerd/tls/src/detect_sni/client_hello.rs diff --git a/linkerd/tls/src/server/testdata/curl-example-com-client-hello.bin b/linkerd/tls/src/detect_sni/testdata/curl-example-com-client-hello.bin similarity index 100% rename from linkerd/tls/src/server/testdata/curl-example-com-client-hello.bin rename to linkerd/tls/src/detect_sni/testdata/curl-example-com-client-hello.bin diff --git a/linkerd/tls/src/server/testdata/example-com-client-hello.bin b/linkerd/tls/src/detect_sni/testdata/example-com-client-hello.bin similarity index 100% 
rename from linkerd/tls/src/server/testdata/example-com-client-hello.bin rename to linkerd/tls/src/detect_sni/testdata/example-com-client-hello.bin diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 0e54d86442..2efa2af6d4 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,10 +2,12 @@ #![forbid(unsafe_code)] pub mod client; +pub mod detect_sni; pub mod server; pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, + detect_sni::NewDetectSNI, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index 04862401f9..c8b8c5ada8 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -1,12 +1,9 @@ -mod client_hello; - -use crate::{NegotiatedProtocol, ServerName}; -use bytes::BytesMut; +use crate::{detect_sni::DetectIo, NegotiatedProtocol, ServerName}; use futures::prelude::*; use linkerd_conditional::Conditional; use linkerd_error::Error; use linkerd_identity as id; -use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; +use linkerd_io::{self as io, EitherIo}; use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Param, Service, ServiceExt}; use std::{ fmt, @@ -14,9 +11,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use thiserror::Error; -use tokio::time::{self, Duration}; -use tracing::{debug, trace, warn}; +use tracing::{debug, trace}; /// Describes the authenticated identity of a remote client. #[derive(Clone, Debug, Eq, PartialEq, Hash)] @@ -54,8 +49,6 @@ pub enum NoServerTls { /// Indicates whether TLS was established on an accepted connection. 
pub type ConditionalServerTls = Conditional; -pub type DetectIo = EitherIo>; - pub type Io = EitherIo>; #[derive(Clone, Debug)] @@ -65,31 +58,15 @@ pub struct NewDetectTls { _local_identity: std::marker::PhantomData L>, } -#[derive(Copy, Clone, Debug)] -pub struct Timeout(pub Duration); - -#[derive(Clone, Debug, Error)] -#[error("TLS detection timed out")] -pub struct ServerTlsTimeoutError(()); - #[derive(Clone, Debug)] pub struct DetectTls { target: T, local_identity: L, - timeout: Timeout, params: P, inner: N, + sni: Option, } -// The initial peek buffer is fairly small so that we can avoid allocating more -// data then we need; but it is large enough to hold the ~300B ClientHello sent -// by proxies. -const PEEK_CAPACITY: usize = 512; - -// A larger fallback buffer is allocated onto the heap if the initial peek -// buffer is insufficient. This is the same value used in HTTP detection. -const BUFFER_CAPACITY: usize = 8192; - impl NewDetectTls { pub fn new(params: P, inner: N) -> Self { Self { @@ -107,27 +84,27 @@ impl NewDetectTls { } } -impl NewService for NewDetectTls +impl NewService<(T, Option)> for NewDetectTls where - P: ExtractParam + ExtractParam + Clone, + P: ExtractParam + Clone, N: Clone, { type Service = DetectTls; - fn new_service(&self, target: T) -> Self::Service { - let timeout = self.params.extract_param(&target); + fn new_service(&self, t: (T, Option)) -> Self::Service { + let (target, sni) = t; let local_identity = self.params.extract_param(&target); DetectTls { target, local_identity, - timeout, + sni, params: self.params.clone(), inner: self.inner.clone(), } } } -impl Service for DetectTls +impl Service> for DetectTls where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin + 'static, T: Clone + Send + 'static, @@ -150,19 +127,14 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, io: I) -> Self::Future { + fn call(&mut self, io: DetectIo) -> Self::Future { let target = self.target.clone(); let params = self.params.clone(); 
let new_accept = self.inner.clone(); - + let sni = self.sni.clone(); let tls = self.local_identity.clone(); - // Detect the SNI from a ClientHello (or timeout). - let Timeout(timeout) = self.timeout; - let detect = time::timeout(timeout, detect_sni(io)); Box::pin(async move { - let (sni, io) = detect.await.map_err(|_| ServerTlsTimeoutError(()))??; - let local_server_name = tls.param(); let (peer, io) = match sni { // If we detected an SNI matching this proxy, terminate TLS. @@ -191,61 +163,6 @@ where } } -/// Peek or buffer the provided stream to determine an SNI value. -async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> -where - I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, -{ - // First, try to use MSG_PEEK to read the SNI from the TLS ClientHello. We - // use a heap-allocated buffer to avoid creating a large `Future` (since we - // need to hold the buffer across an await). - // - // Anecdotally, the ClientHello sent by Linkerd proxies is <300B. So a ~500B - // byte buffer is more than enough. - let mut buf = BytesMut::with_capacity(PEEK_CAPACITY); - let sz = io.peek(&mut buf).await?; - debug!(sz, "Peeked bytes from TCP stream"); - // Peek may return 0 bytes if the socket is not peekable. - if sz > 0 { - match client_hello::parse_sni(buf.as_ref()) { - Ok(sni) => { - return Ok((sni, EitherIo::Left(io))); - } - - Err(client_hello::Incomplete) => {} - } - } - - // Peeking didn't return enough data, so instead we'll allocate more - // capacity and try reading data from the socket. - debug!("Attempting to buffer TLS ClientHello after incomplete peek"); - let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY); - debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream"); - while io.read_buf(&mut buf).await? 
!= 0 { - debug!(buf.len = %buf.len(), "Read bytes from TCP stream"); - match client_hello::parse_sni(buf.as_ref()) { - Ok(sni) => { - return Ok((sni, EitherIo::Right(PrefixedIo::new(buf.freeze(), io)))); - } - - Err(client_hello::Incomplete) => { - if buf.capacity() == 0 { - // If we can't buffer an entire TLS ClientHello, it - // almost definitely wasn't initiated by another proxy, - // at least. - warn!("Buffer insufficient for TLS ClientHello"); - break; - } - // Continue if there is still buffer capacity. - } - } - } - - trace!("Could not read TLS ClientHello via buffering"); - let io = EitherIo::Right(PrefixedIo::new(buf.freeze(), io)); - Ok((None, io)) -} - // === impl ClientId === impl From for ClientId { @@ -304,46 +221,3 @@ impl ServerTls { } } } - -#[cfg(test)] -mod tests { - use super::*; - use linkerd_io::AsyncWriteExt; - - #[tokio::test(flavor = "current_thread")] - async fn detect_buffered() { - let _trace = linkerd_tracing::test::trace_init(); - - let (mut client_io, server_io) = linkerd_io::duplex(1024); - let input = include_bytes!("server/testdata/curl-example-com-client-hello.bin"); - let len = input.len(); - let client_task = tokio::spawn(async move { - client_io - .write_all(input) - .await - .expect("Write must succeed"); - }); - - let (sni, io) = detect_sni(server_io) - .await - .expect("SNI detection must not fail"); - - assert_eq!(sni, Some(ServerName("example.com".parse().unwrap()))); - - match io { - EitherIo::Left(_) => panic!("Detected IO should be buffered"), - EitherIo::Right(io) => assert_eq!(io.prefix().len(), len, "All data must be buffered"), - } - - client_task.await.expect("Client must not fail"); - } -} - -#[cfg(fuzzing)] -pub mod fuzz_logic { - use super::*; - - pub fn fuzz_entry(input: &[u8]) { - let _ = client_hello::parse_sni(input); - } -} From d909be7d3ea66c96e84ade0e3f55e99cc6b66aae Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Thu, 12 Sep 2024 07:31:38 +0000 Subject: [PATCH 2/5] remove println Signed-off-by: Zahari 
Dichev --- linkerd/tls/src/detect_sni.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs index 66f034742b..95730200a5 100644 --- a/linkerd/tls/src/detect_sni.rs +++ b/linkerd/tls/src/detect_sni.rs @@ -99,7 +99,7 @@ where Box::pin(async move { let (sni, io) = detect.await.map_err(|_| DetectSniTimeoutError(()))??; - println!("detected SNI: {:?}", sni); + debug!("detected SNI: {:?}", sni); let svc = new_accept.new_service((target, sni)); svc.oneshot(io).await.map_err(Into::into) }) From cfcfa0ac3d2efa658b1966ebcd9574795fbc178b Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Fri, 13 Sep 2024 10:54:18 +0000 Subject: [PATCH 3/5] Revert "remove println" This reverts commit d909be7d3ea66c96e84ade0e3f55e99cc6b66aae. Signed-off-by: Zahari Dichev --- linkerd/tls/src/detect_sni.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs index 95730200a5..66f034742b 100644 --- a/linkerd/tls/src/detect_sni.rs +++ b/linkerd/tls/src/detect_sni.rs @@ -99,7 +99,7 @@ where Box::pin(async move { let (sni, io) = detect.await.map_err(|_| DetectSniTimeoutError(()))??; - debug!("detected SNI: {:?}", sni); + println!("detected SNI: {:?}", sni); let svc = new_accept.new_service((target, sni)); svc.oneshot(io).await.map_err(Into::into) }) From d354cc384274915e0eac80c89c5a9fbb2f2c2ebe Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Fri, 13 Sep 2024 10:54:22 +0000 Subject: [PATCH 4/5] Revert "chore: split TLS detection logic" This reverts commit 6c0780443ce7d29699df5b76b7d8d2acbcfd64ad. 
Signed-off-by: Zahari Dichev --- linkerd/app/admin/src/stack.rs | 8 +- linkerd/app/inbound/src/detect.rs | 22 +- linkerd/app/inbound/src/direct.rs | 22 +- linkerd/app/inbound/src/metrics/error.rs | 2 +- linkerd/app/src/tap.rs | 12 +- linkerd/meshtls/tests/util.rs | 57 ++--- linkerd/proxy/tap/src/accept.rs | 2 +- linkerd/tls/src/detect_sni.rs | 211 ------------------ linkerd/tls/src/lib.rs | 2 - linkerd/tls/src/server.rs | 150 ++++++++++++- .../{detect_sni => server}/client_hello.rs | 0 .../curl-example-com-client-hello.bin | Bin .../testdata/example-com-client-hello.bin | Bin 13 files changed, 219 insertions(+), 269 deletions(-) delete mode 100644 linkerd/tls/src/detect_sni.rs rename linkerd/tls/src/{detect_sni => server}/client_hello.rs (100%) rename linkerd/tls/src/{detect_sni => server}/testdata/curl-example-com-client-hello.bin (100%) rename linkerd/tls/src/{detect_sni => server}/testdata/example-com-client-hello.bin (100%) diff --git a/linkerd/app/admin/src/stack.rs b/linkerd/app/admin/src/stack.rs index 49e4991ced..7f36f5d392 100644 --- a/linkerd/app/admin/src/stack.rs +++ b/linkerd/app/admin/src/stack.rs @@ -193,7 +193,6 @@ impl Config { .push(tls::NewDetectTls::::layer(TlsParams { identity, })) - .push(tls::NewDetectSNI::layer(DETECT_TIMEOUT.into())) .arc_new_tcp() .into_inner(); @@ -280,6 +279,13 @@ impl Param for Permitted { // === TlsParams === +impl ExtractParam for TlsParams { + #[inline] + fn extract_param(&self, _: &T) -> tls::server::Timeout { + tls::server::Timeout(DETECT_TIMEOUT) + } +} + impl ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { diff --git a/linkerd/app/inbound/src/detect.rs b/linkerd/app/inbound/src/detect.rs index a492c7cb41..e77b85b1cc 100644 --- a/linkerd/app/inbound/src/detect.rs +++ b/linkerd/app/inbound/src/detect.rs @@ -52,9 +52,12 @@ struct Detect { struct ConfigureHttpDetect; #[derive(Clone)] -struct TlsParams(identity::Server); +struct TlsParams { + timeout: tls::server::Timeout, + 
identity: identity::Server, +} -type TlsIo = tls::server::Io>, I>; +type TlsIo = tls::server::Io>, I>; // === impl Inbound === @@ -241,9 +244,11 @@ impl Inbound>> { ) .arc_new_tcp() .push(tls::NewDetectTls::::layer( - TlsParams(rt.identity.server()), + TlsParams { + timeout: tls::server::Timeout(detect_timeout), + identity: rt.identity.server(), + }, )) - .push(tls::NewDetectSNI::layer(detect_timeout.into())) .arc_new_tcp() .push_switch( // Check the policy for this port and check whether @@ -415,10 +420,17 @@ impl svc::Param for Http { // === TlsParams === +impl svc::ExtractParam for TlsParams { + #[inline] + fn extract_param(&self, _: &T) -> tls::server::Timeout { + self.timeout + } +} + impl svc::ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { - self.0.clone() + self.identity.clone() } } diff --git a/linkerd/app/inbound/src/direct.rs b/linkerd/app/inbound/src/direct.rs index ba8edb3ea2..5025556ffd 100644 --- a/linkerd/app/inbound/src/direct.rs +++ b/linkerd/app/inbound/src/direct.rs @@ -70,12 +70,15 @@ pub struct ClientInfo { pub local_addr: OrigDstAddr, } -type TlsIo = tls::server::Io>, I>; +type TlsIo = tls::server::Io>, I>; type FwdIo = SensorIo>>; pub type GatewayIo = FwdIo; #[derive(Clone)] -struct TlsParams(identity::Server); +struct TlsParams { + timeout: tls::server::Timeout, + identity: identity::Server, +} impl Inbound { /// Builds a stack that handles connections that target the proxy's inbound port @@ -217,9 +220,11 @@ impl Inbound { }) .push(svc::ArcNewService::layer()) .push(tls::NewDetectTls::::layer( - TlsParams(identity), + TlsParams { + timeout: tls::server::Timeout(detect_timeout), + identity, + }, )) - .push(tls::NewDetectSNI::layer(detect_timeout.into())) .arc_new_tcp() }) } @@ -448,10 +453,17 @@ impl From for Error { // === TlsParams === +impl ExtractParam for TlsParams { + #[inline] + fn extract_param(&self, _: &T) -> tls::server::Timeout { + self.timeout + } +} + impl ExtractParam for TlsParams 
{ #[inline] fn extract_param(&self, _: &T) -> identity::Server { - self.0.clone() + self.identity.clone() } } diff --git a/linkerd/app/inbound/src/metrics/error.rs b/linkerd/app/inbound/src/metrics/error.rs index 63db374b9c..b83313117a 100644 --- a/linkerd/app/inbound/src/metrics/error.rs +++ b/linkerd/app/inbound/src/metrics/error.rs @@ -42,7 +42,7 @@ impl ErrorKind { Some(ErrorKind::FailFast) } else if err.is::() { Some(ErrorKind::Io) - } else if err.is::() { + } else if err.is::() { Some(ErrorKind::TlsDetectTimeout) } else if err.is::() { Some(ErrorKind::GatewayDomainInvalid) diff --git a/linkerd/app/src/tap.rs b/linkerd/app/src/tap.rs index 3d07a422a1..c9f5d6cee5 100644 --- a/linkerd/app/src/tap.rs +++ b/linkerd/app/src/tap.rs @@ -9,7 +9,7 @@ use linkerd_app_core::{ transport::{addrs::AddrPair, listen::Bind, ClientAddr, Local, Remote, ServerAddr}, Error, }; -use std::{collections::HashSet, pin::Pin, time::Duration}; +use std::{collections::HashSet, pin::Pin}; use tower::util::{service_fn, ServiceExt}; #[derive(Clone, Debug)] @@ -38,8 +38,6 @@ struct TlsParams { identity: identity::Server, } -const DETECT_TIMEOUT: Duration = Duration::from_secs(1); - impl Config { pub fn build( self, @@ -85,7 +83,6 @@ impl Config { .push(tls::NewDetectTls::::layer( TlsParams { identity }, )) - .push(tls::NewDetectSNI::layer(DETECT_TIMEOUT.into())) .check_new_service::() .into_inner(); @@ -112,6 +109,13 @@ impl Tap { // === TlsParams === +impl ExtractParam for TlsParams { + #[inline] + fn extract_param(&self, _: &T) -> tls::server::Timeout { + tls::server::Timeout(std::time::Duration::from_secs(1)) + } +} + impl ExtractParam for TlsParams { #[inline] fn extract_param(&self, _: &T) -> identity::Server { diff --git a/linkerd/meshtls/tests/util.rs b/linkerd/meshtls/tests/util.rs index 5859e43585..c294ecfecc 100644 --- a/linkerd/meshtls/tests/util.rs +++ b/linkerd/meshtls/tests/util.rs @@ -155,7 +155,7 @@ pub async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match(mode: 
type ServerConn = ( (tls::ConditionalServerTls, T), - io::EitherIo>, tls::detect_sni::DetectIo>, + io::EitherIo>, tls::server::DetectIo>, ); fn load( @@ -228,34 +228,31 @@ where // Saves the result of every connection. let (sender, receiver) = mpsc::channel::>(); - let detect = tls::NewDetectSNI::new( - tls::NewDetectTls::::new( - ServerParams { - identity: server_tls, - }, - move |meta: (tls::ConditionalServerTls, Addrs)| { + let detect = tls::NewDetectTls::::new( + ServerParams { + identity: server_tls, + }, + move |meta: (tls::ConditionalServerTls, Addrs)| { + let server = server.clone(); + let sender = sender.clone(); + let tls = meta.0.clone().map(Into::into); + service_fn(move |conn| { let server = server.clone(); let sender = sender.clone(); - let tls = meta.0.clone().map(Into::into); - service_fn(move |conn| { - let server = server.clone(); - let sender = sender.clone(); - let tls = Some(tls.clone()); - let future = server((meta.clone(), conn)); - Box::pin( - async move { - let result = future.await; - sender - .send(Transported { tls, result }) - .expect("send result"); - Ok::<(), Infallible>(()) - } - .instrument(tracing::info_span!("test_svc")), - ) - }) - }, - ), - Duration::from_secs(10).into(), + let tls = Some(tls.clone()); + let future = server((meta.clone(), conn)); + Box::pin( + async move { + let result = future.await; + sender + .send(Transported { tls, result }) + .expect("send result"); + Ok::<(), Infallible>(()) + } + .instrument(tracing::info_span!("test_svc")), + ) + }) + }, ); let (listen_addr, listen) = BindTcp::default().bind(&Server).expect("must bind"); @@ -412,6 +409,12 @@ impl Param for Server { // === impl ServerParams === +impl ExtractParam for ServerParams { + fn extract_param(&self, _: &T) -> tls::server::Timeout { + tls::server::Timeout(Duration::from_secs(10)) + } +} + impl ExtractParam for ServerParams { fn extract_param(&self, _: &T) -> meshtls::Server { self.identity.clone() diff --git a/linkerd/proxy/tap/src/accept.rs 
b/linkerd/proxy/tap/src/accept.rs index de4ee7f76b..57f9411958 100644 --- a/linkerd/proxy/tap/src/accept.rs +++ b/linkerd/proxy/tap/src/accept.rs @@ -24,7 +24,7 @@ pub struct AcceptPermittedClients { type Connection = ( (tls::ConditionalServerTls, T), - io::EitherIo>, tls::detect_sni::DetectIo>, + io::EitherIo>, tls::server::DetectIo>, ); pub type ServeFuture = Pin> + Send + 'static>>; diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs deleted file mode 100644 index 66f034742b..0000000000 --- a/linkerd/tls/src/detect_sni.rs +++ /dev/null @@ -1,211 +0,0 @@ -mod client_hello; - -use crate::ServerName; -use bytes::BytesMut; -use futures::prelude::*; -use linkerd_error::Error; -use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; -use linkerd_stack::{layer, NewService, Service, ServiceExt}; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; -use thiserror::Error; -use tokio::time::{self, Duration}; -use tracing::{debug, trace, warn}; - -pub type DetectIo = EitherIo>; - -#[derive(Clone, Debug)] -pub struct NewDetectSNI { - inner: N, - timeout: Timeout, -} - -#[derive(Copy, Clone, Debug)] -pub struct Timeout(pub Duration); - -#[derive(Clone, Debug, Error)] -#[error("SNI detection timed out")] -pub struct DetectSniTimeoutError(()); - -/// Attempts to detect an SNI from the client hello of a TLS session -#[derive(Clone, Debug)] -pub struct DetectSNI { - target: T, - inner: N, - timeout: Timeout, -} - -// The initial peek buffer is fairly small so that we can avoid allocating more -// data then we need; but it is large enough to hold the ~300B ClientHello sent -// by proxies. -const PEEK_CAPACITY: usize = 512; - -// A larger fallback buffer is allocated onto the heap if the initial peek -// buffer is insufficient. This is the same value used in HTTP detection. 
-const BUFFER_CAPACITY: usize = 8192; - -impl NewDetectSNI { - pub fn new(inner: N, timeout: Timeout) -> Self { - Self { inner, timeout } - } - - pub fn layer(timeout: Timeout) -> impl layer::Layer + Clone { - layer::mk(move |inner| Self::new(inner, timeout)) - } -} - -impl NewService for NewDetectSNI -where - N: Clone, -{ - type Service = DetectSNI; - - fn new_service(&self, target: T) -> Self::Service { - DetectSNI { - target, - inner: self.inner.clone(), - timeout: self.timeout, - } - } -} - -impl Service for DetectSNI -where - T: Clone + Send + 'static, - I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, - N: NewService<(T, Option), Service = S> + Clone + Send + 'static, - S: Service> + Send, - S::Error: Into, - S::Future: Send, -{ - type Response = S::Response; - type Error = Error; - type Future = Pin> + Send + 'static>>; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, io: I) -> Self::Future { - let target = self.target.clone(); - let new_accept = self.inner.clone(); - - // Detect the SNI from a ClientHello (or timeout). - let Timeout(timeout) = self.timeout; - let detect = time::timeout(timeout, detect_sni(io)); - Box::pin(async move { - let (sni, io) = detect.await.map_err(|_| DetectSniTimeoutError(()))??; - - println!("detected SNI: {:?}", sni); - let svc = new_accept.new_service((target, sni)); - svc.oneshot(io).await.map_err(Into::into) - }) - } -} - -/// Peek or buffer the provided stream to determine an SNI value. -async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> -where - I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, -{ - // First, try to use MSG_PEEK to read the SNI from the TLS ClientHello. We - // use a heap-allocated buffer to avoid creating a large `Future` (since we - // need to hold the buffer across an await). - // - // Anecdotally, the ClientHello sent by Linkerd proxies is <300B. 
So a ~500B - // byte buffer is more than enough. - let mut buf = BytesMut::with_capacity(PEEK_CAPACITY); - let sz = io.peek(&mut buf).await?; - debug!(sz, "Peeked bytes from TCP stream"); - // Peek may return 0 bytes if the socket is not peekable. - if sz > 0 { - match client_hello::parse_sni(buf.as_ref()) { - Ok(sni) => { - return Ok((sni, EitherIo::Left(io))); - } - - Err(client_hello::Incomplete) => {} - } - } - - // Peeking didn't return enough data, so instead we'll allocate more - // capacity and try reading data from the socket. - debug!("Attempting to buffer TLS ClientHello after incomplete peek"); - let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY); - debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream"); - while io.read_buf(&mut buf).await? != 0 { - debug!(buf.len = %buf.len(), "Read bytes from TCP stream"); - match client_hello::parse_sni(buf.as_ref()) { - Ok(sni) => { - return Ok((sni, EitherIo::Right(PrefixedIo::new(buf.freeze(), io)))); - } - - Err(client_hello::Incomplete) => { - if buf.capacity() == 0 { - // If we can't buffer an entire TLS ClientHello, it - // almost definitely wasn't initiated by another proxy, - // at least. - warn!("Buffer insufficient for TLS ClientHello"); - break; - } - // Continue if there is still buffer capacity. 
- } - } - } - - trace!("Could not read TLS ClientHello via buffering"); - let io = EitherIo::Right(PrefixedIo::new(buf.freeze(), io)); - Ok((None, io)) -} - -impl From for Timeout { - fn from(d: Duration) -> Self { - Self(d) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use linkerd_io::AsyncWriteExt; - - #[tokio::test(flavor = "current_thread")] - async fn detect_buffered() { - let _trace = linkerd_tracing::test::trace_init(); - - let (mut client_io, server_io) = linkerd_io::duplex(1024); - let input = include_bytes!("detect_sni/testdata/curl-example-com-client-hello.bin"); - let len = input.len(); - let client_task = tokio::spawn(async move { - client_io - .write_all(input) - .await - .expect("Write must succeed"); - }); - - let (sni, io) = detect_sni(server_io) - .await - .expect("SNI detection must not fail"); - - assert_eq!(sni, Some(ServerName("example.com".parse().unwrap()))); - - match io { - EitherIo::Left(_) => panic!("Detected IO should be buffered"), - EitherIo::Right(io) => assert_eq!(io.prefix().len(), len, "All data must be buffered"), - } - - client_task.await.expect("Client must not fail"); - } -} - -#[cfg(fuzzing)] -pub mod fuzz_logic { - use super::*; - - pub fn fuzz_entry(input: &[u8]) { - let _ = client_hello::parse_sni(input); - } -} diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 2efa2af6d4..0e54d86442 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,12 +2,10 @@ #![forbid(unsafe_code)] pub mod client; -pub mod detect_sni; pub mod server; pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, - detect_sni::NewDetectSNI, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index c8b8c5ada8..04862401f9 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -1,9 +1,12 @@ -use crate::{detect_sni::DetectIo, NegotiatedProtocol, ServerName}; 
+mod client_hello; + +use crate::{NegotiatedProtocol, ServerName}; +use bytes::BytesMut; use futures::prelude::*; use linkerd_conditional::Conditional; use linkerd_error::Error; use linkerd_identity as id; -use linkerd_io::{self as io, EitherIo}; +use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Param, Service, ServiceExt}; use std::{ fmt, @@ -11,7 +14,9 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{debug, trace}; +use thiserror::Error; +use tokio::time::{self, Duration}; +use tracing::{debug, trace, warn}; /// Describes the authenticated identity of a remote client. #[derive(Clone, Debug, Eq, PartialEq, Hash)] @@ -49,6 +54,8 @@ pub enum NoServerTls { /// Indicates whether TLS was established on an accepted connection. pub type ConditionalServerTls = Conditional; +pub type DetectIo = EitherIo>; + pub type Io = EitherIo>; #[derive(Clone, Debug)] @@ -58,15 +65,31 @@ pub struct NewDetectTls { _local_identity: std::marker::PhantomData L>, } +#[derive(Copy, Clone, Debug)] +pub struct Timeout(pub Duration); + +#[derive(Clone, Debug, Error)] +#[error("TLS detection timed out")] +pub struct ServerTlsTimeoutError(()); + #[derive(Clone, Debug)] pub struct DetectTls { target: T, local_identity: L, + timeout: Timeout, params: P, inner: N, - sni: Option, } +// The initial peek buffer is fairly small so that we can avoid allocating more +// data than we need; but it is large enough to hold the ~300B ClientHello sent +// by proxies. +const PEEK_CAPACITY: usize = 512; + +// A larger fallback buffer is allocated onto the heap if the initial peek +// buffer is insufficient. This is the same value used in HTTP detection. 
+const BUFFER_CAPACITY: usize = 8192; + impl NewDetectTls { pub fn new(params: P, inner: N) -> Self { Self { @@ -84,27 +107,27 @@ impl NewDetectTls { } } -impl NewService<(T, Option)> for NewDetectTls +impl NewService for NewDetectTls where - P: ExtractParam + Clone, + P: ExtractParam + ExtractParam + Clone, N: Clone, { type Service = DetectTls; - fn new_service(&self, t: (T, Option)) -> Self::Service { - let (target, sni) = t; + fn new_service(&self, target: T) -> Self::Service { + let timeout = self.params.extract_param(&target); let local_identity = self.params.extract_param(&target); DetectTls { target, local_identity, - sni, + timeout, params: self.params.clone(), inner: self.inner.clone(), } } } -impl Service> for DetectTls +impl Service for DetectTls where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin + 'static, T: Clone + Send + 'static, @@ -127,14 +150,19 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, io: DetectIo) -> Self::Future { + fn call(&mut self, io: I) -> Self::Future { let target = self.target.clone(); let params = self.params.clone(); let new_accept = self.inner.clone(); - let sni = self.sni.clone(); + let tls = self.local_identity.clone(); + // Detect the SNI from a ClientHello (or timeout). + let Timeout(timeout) = self.timeout; + let detect = time::timeout(timeout, detect_sni(io)); Box::pin(async move { + let (sni, io) = detect.await.map_err(|_| ServerTlsTimeoutError(()))??; + let local_server_name = tls.param(); let (peer, io) = match sni { // If we detected an SNI matching this proxy, terminate TLS. @@ -163,6 +191,61 @@ where } } +/// Peek or buffer the provided stream to determine an SNI value. +async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +where + I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, +{ + // First, try to use MSG_PEEK to read the SNI from the TLS ClientHello. 
We + // use a heap-allocated buffer to avoid creating a large `Future` (since we + // need to hold the buffer across an await). + // + // Anecdotally, the ClientHello sent by Linkerd proxies is <300B. So a ~500B + // buffer is more than enough. + let mut buf = BytesMut::with_capacity(PEEK_CAPACITY); + let sz = io.peek(&mut buf).await?; + debug!(sz, "Peeked bytes from TCP stream"); + // Peek may return 0 bytes if the socket is not peekable. + if sz > 0 { + match client_hello::parse_sni(buf.as_ref()) { + Ok(sni) => { + return Ok((sni, EitherIo::Left(io))); + } + + Err(client_hello::Incomplete) => {} + } + } + + // Peeking didn't return enough data, so instead we'll allocate more + // capacity and try reading data from the socket. + debug!("Attempting to buffer TLS ClientHello after incomplete peek"); + let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY); + debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream"); + while io.read_buf(&mut buf).await? != 0 { + debug!(buf.len = %buf.len(), "Read bytes from TCP stream"); + match client_hello::parse_sni(buf.as_ref()) { + Ok(sni) => { + return Ok((sni, EitherIo::Right(PrefixedIo::new(buf.freeze(), io)))); + } + + Err(client_hello::Incomplete) => { + if buf.capacity() == 0 { + // If we can't buffer an entire TLS ClientHello, it + // almost definitely wasn't initiated by another proxy, + // at least. + warn!("Buffer insufficient for TLS ClientHello"); + break; + } + // Continue if there is still buffer capacity. 
+ } + } + } + + trace!("Could not read TLS ClientHello via buffering"); + let io = EitherIo::Right(PrefixedIo::new(buf.freeze(), io)); + Ok((None, io)) +} + // === impl ClientId === impl From for ClientId { @@ -221,3 +304,46 @@ impl ServerTls { } } } + +#[cfg(test)] +mod tests { + use super::*; + use linkerd_io::AsyncWriteExt; + + #[tokio::test(flavor = "current_thread")] + async fn detect_buffered() { + let _trace = linkerd_tracing::test::trace_init(); + + let (mut client_io, server_io) = linkerd_io::duplex(1024); + let input = include_bytes!("server/testdata/curl-example-com-client-hello.bin"); + let len = input.len(); + let client_task = tokio::spawn(async move { + client_io + .write_all(input) + .await + .expect("Write must succeed"); + }); + + let (sni, io) = detect_sni(server_io) + .await + .expect("SNI detection must not fail"); + + assert_eq!(sni, Some(ServerName("example.com".parse().unwrap()))); + + match io { + EitherIo::Left(_) => panic!("Detected IO should be buffered"), + EitherIo::Right(io) => assert_eq!(io.prefix().len(), len, "All data must be buffered"), + } + + client_task.await.expect("Client must not fail"); + } +} + +#[cfg(fuzzing)] +pub mod fuzz_logic { + use super::*; + + pub fn fuzz_entry(input: &[u8]) { + let _ = client_hello::parse_sni(input); + } +} diff --git a/linkerd/tls/src/detect_sni/client_hello.rs b/linkerd/tls/src/server/client_hello.rs similarity index 100% rename from linkerd/tls/src/detect_sni/client_hello.rs rename to linkerd/tls/src/server/client_hello.rs diff --git a/linkerd/tls/src/detect_sni/testdata/curl-example-com-client-hello.bin b/linkerd/tls/src/server/testdata/curl-example-com-client-hello.bin similarity index 100% rename from linkerd/tls/src/detect_sni/testdata/curl-example-com-client-hello.bin rename to linkerd/tls/src/server/testdata/curl-example-com-client-hello.bin diff --git a/linkerd/tls/src/detect_sni/testdata/example-com-client-hello.bin b/linkerd/tls/src/server/testdata/example-com-client-hello.bin 
similarity index 100% rename from linkerd/tls/src/detect_sni/testdata/example-com-client-hello.bin rename to linkerd/tls/src/server/testdata/example-com-client-hello.bin From 22c3a9cf87da3257fd1d7497a52434fe18f275b0 Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Fri, 13 Sep 2024 10:42:58 +0000 Subject: [PATCH 5/5] add independent DetectSni middleware Signed-off-by: Zahari Dichev --- linkerd/tls/src/detect_sni.rs | 107 ++++++++++++++++++++++++++++++++++ linkerd/tls/src/lib.rs | 1 + linkerd/tls/src/server.rs | 2 +- 3 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 linkerd/tls/src/detect_sni.rs diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs new file mode 100644 index 0000000000..45747d63ed --- /dev/null +++ b/linkerd/tls/src/detect_sni.rs @@ -0,0 +1,107 @@ +use crate::{ + server::{detect_sni, DetectIo, Timeout}, + ServerName, +}; +use linkerd_error::Error; +use linkerd_io as io; +use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tokio::time; +use tracing::debug; + +#[derive(Clone, Debug, Error)] +#[error("SNI detection timed out")] +pub struct SniDetectionTimeoutError; + +#[derive(Clone, Debug, Error)] +#[error("Could not find SNI")] +pub struct NoSniFoundError; + +#[derive(Clone, Debug)] +pub struct NewDetectSni { + params: P, + inner: N, +} + +#[derive(Clone, Debug)] +pub struct DetectSni { + target: T, + inner: N, + timeout: Timeout, + params: P, +} + +impl NewDetectSni { + pub fn new(params: P, inner: N) -> Self { + Self { inner, params } + } + + pub fn layer(params: P) -> impl layer::Layer + Clone + where + P: Clone, + { + layer::mk(move |inner| Self::new(params.clone(), inner)) + } +} + +impl NewService for NewDetectSni +where + P: ExtractParam + Clone, + N: Clone, +{ + type Service = DetectSni; + + fn new_service(&self, target: T) -> Self::Service { + let timeout = 
self.params.extract_param(&target); + DetectSni { + target, + timeout, + inner: self.inner.clone(), + params: self.params.clone(), + } + } +} + +impl Service for DetectSni +where + T: Clone + Send + Sync + 'static, + P: InsertParam + Clone + Send + Sync + 'static, + P::Target: Send + 'static, + I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, + N: NewService + Clone + Send + 'static, + S: Service> + Send, + S::Error: Into, + S::Future: Send, +{ + type Response = S::Response; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { + let target = self.target.clone(); + let new_accept = self.inner.clone(); + let params = self.params.clone(); + + // Detect the SNI from a ClientHello (or timeout). + let Timeout(timeout) = self.timeout; + let detect = time::timeout(timeout, detect_sni(io)); + Box::pin(async move { + let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; + let sni = sni.ok_or(NoSniFoundError)?; + + debug!("detected SNI: {:?}", sni); + let svc = new_accept.new_service(params.insert_param(sni, target)); + svc.oneshot(io).await.map_err(Into::into) + }) + } +} diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 0e54d86442..0a281e2b36 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,6 +2,7 @@ #![forbid(unsafe_code)] pub mod client; +pub mod detect_sni; pub mod server; pub use self::{ diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index 04862401f9..1c85c92ee6 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -192,7 +192,7 @@ where } /// Peek or buffer the provided stream to determine an SNI value. 
-async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +pub(crate) async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, {