diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b2ccf591ec15..b57a699e9f57 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -812,6 +812,23 @@ jobs: # Run the tests! - run: cargo test -p wasmtime-wasi-nn --features ${{ matrix.feature }} + # Test `wasmtime-wasi-tls-nativetls` in its own job. This is because it + # depends on OpenSSL, which is not easily available on all platforms. + test_wasi_tls_nativetls: + name: Test wasi-tls using native-tls provider + needs: determine + if: needs.determine.outputs.run-full + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - uses: ./.github/actions/install-rust + - run: cargo test -p wasmtime-wasi-tls-nativetls + # Test the `wasmtime-fuzzing` crate. Split out from the main tests because # `--all-features` brings in OCaml, which is a pain to get setup for all # targets. @@ -1114,6 +1131,7 @@ jobs: - doc - micro_checks - special_tests + - test_wasi_tls_nativetls - clippy - monolith_checks - platform_checks diff --git a/Cargo.lock b/Cargo.lock index d594cac0fe3d..70b93f148663 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -627,6 +627,16 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -1371,6 +1381,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2327,6 +2352,23 @@ dependencies = [ "rand", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -2462,6 +2504,50 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "openssl" +version = "0.10.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "openvino" version = "0.9.0" @@ -3014,6 +3100,38 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.17" @@ -3502,6 +3620,16 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.25.0" @@ -3767,6 +3895,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "veri_engine" version = "0.1.0" @@ -4773,6 +4907,21 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "wasmtime-wasi-tls-nativetls" +version = "35.0.0" +dependencies = [ + "anyhow", + "futures", + "native-tls", + "test-programs-artifacts", + "tokio", + "tokio-native-tls", + "wasmtime", + "wasmtime-wasi", + "wasmtime-wasi-tls", +] + [[package]] name = "wasmtime-wast" version = "35.0.0" diff --git a/Cargo.toml b/Cargo.toml index 0634c52dc5d8..f52f1cb2bee2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,6 +158,7 @@ members = [ "crates/test-programs", "crates/wasi-preview1-component-adapter", "crates/wasi-preview1-component-adapter/verify", + "crates/wasi-tls-nativetls", "examples/fib-debug/wasm", "examples/wasm", "examples/tokio/wasm", @@ -235,6 +236,7 @@ wasmtime-wasi-config = { path = "crates/wasi-config", version = "35.0.0" } wasmtime-wasi-keyvalue = { path = "crates/wasi-keyvalue", version = "35.0.0" } wasmtime-wasi-threads = { path = "crates/wasi-threads", version = "35.0.0" } wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "35.0.0" } +wasmtime-wasi-tls-nativetls = { path = "crates/wasi-tls-nativetls", version = "35.0.0" } wasmtime-wast = { path = "crates/wast", version = "=35.0.0" } # Internal Wasmtime-specific crates. @@ -399,6 +401,8 @@ ittapi = "0.4.0" libm = "0.2.7" tokio-rustls = "0.25.0" rustls = "0.22.0" +tokio-native-tls = "0.3.1" +native-tls = "0.2.11" webpki-roots = "0.26.0" itertools = "0.14.0" base64 = "0.22.1" diff --git a/ci/run-tests.py b/ci/run-tests.py index 3ddb5983a983..8f354788d63a 100755 --- a/ci/run-tests.py +++ b/ci/run-tests.py @@ -7,6 +7,9 @@ # - wasmtime-wasi-nn: mutually-exclusive features that aren't available for all # targets, needs its own CI job. # +# - wasmtime-wasi-tls-nativetls: the openssl dependency does not play nice with +# cross compilation. This crate is tested in a separate CI job. +# # - wasmtime-fuzzing: enabling all features brings in OCaml which is a pain to # configure for all targets, so it has its own CI job. # @@ -21,6 +24,7 @@ args = ['cargo', 'test', '--workspace', '--all-features'] args.append('--exclude=test-programs') args.append('--exclude=wasmtime-wasi-nn') +args.append('--exclude=wasmtime-wasi-tls-nativetls') args.append('--exclude=wasmtime-fuzzing') args.append('--exclude=wasm-spec-interpreter') args.append('--exclude=veri_engine') diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs index 2c570fddedff..6fa7a8344262 100644 --- a/crates/test-programs/src/bin/tls_sample_application.rs +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -7,8 +7,9 @@ use test_programs::wasi::tls::types::ClientHandshake; const PORT: u16 = 443; fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { - let request = - format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n"); + let request = format!( + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" + ); let net = Network::default(); @@ -25,13 +26,13 @@ fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { tls_output .blocking_write_util(request.as_bytes()) .context("writing http request failed")?; - client_connection - .blocking_close_output(&tls_output) - .context("closing tls connection failed")?; - socket.shutdown(ShutdownType::Send)?; let response = tls_input .blocking_read_to_end() .context("reading http response failed")?; + client_connection + .blocking_close_output(&tls_output) + .context("closing tls connection failed")?; + socket.shutdown(ShutdownType::Both)?; if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") { Ok(()) @@ -55,7 +56,7 @@ fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() { // We're expecting an error regarding the "certificate" is some form or - // another. When we add more TLS backends other than rustls, this naive + // another. When we add more TLS backends this naive // check will likely need to be revisited/expanded: Err(e) if e.to_debug_string().contains("certificate") => Ok(()), diff --git a/crates/wasi-tls-nativetls/Cargo.toml b/crates/wasi-tls-nativetls/Cargo.toml new file mode 100644 index 000000000000..6449ec8a7fa5 --- /dev/null +++ b/crates/wasi-tls-nativetls/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "wasmtime-wasi-tls-nativetls" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +repository = "https://github.com/bytecodealliance/wasmtime" +license = "Apache-2.0 WITH LLVM-exception" +description = "Wasmtime implementation of the wasi-tls API, using native-tls for TLS support." + +[lints] +workspace = true + +[dependencies] +wasmtime-wasi-tls = { workspace = true } +tokio = { workspace = true } +tokio-native-tls = { workspace = true } +native-tls = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +test-programs-artifacts = { workspace = true } +wasmtime = { workspace = true, features = ["runtime", "component-model"] } +wasmtime-wasi = { workspace = true } +tokio = { workspace = true, features = ["macros"] } +futures = { workspace = true } diff --git a/crates/wasi-tls-nativetls/src/lib.rs b/crates/wasi-tls-nativetls/src/lib.rs new file mode 100644 index 000000000000..488614512dcb --- /dev/null +++ b/crates/wasi-tls-nativetls/src/lib.rs @@ -0,0 +1,82 @@ +//! The `native_tls` provider. + +use std::{io, pin::pin}; + +use wasmtime_wasi_tls::{TlsProvider, TlsStream, TlsTransport}; + +type BoxFuture = std::pin::Pin + Send>>; + +/// The `native_tls` provider. +pub struct NativeTlsProvider { + _priv: (), +} + +impl TlsProvider for NativeTlsProvider { + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>> { + async fn connect_impl( + server_name: String, + transport: Box, + ) -> Result { + let connector = native_tls::TlsConnector::new()?; + let stream = tokio_native_tls::TlsConnector::from(connector) + .connect(&server_name, transport) + .await?; + Ok(NativeTlsStream(stream)) + } + + Box::pin(async move { + let stream = connect_impl(server_name, transport) + .await + .map_err(|e| io::Error::other(e))?; + Ok(Box::new(stream) as Box) + }) + } +} + +impl Default for NativeTlsProvider { + fn default() -> Self { + Self { _priv: () } + } +} + +struct NativeTlsStream(tokio_native_tls::TlsStream>); + +impl TlsStream for NativeTlsStream {} + +impl tokio::io::AsyncRead for NativeTlsStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for NativeTlsStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_shutdown(cx) + } +} diff --git a/crates/wasi-tls-nativetls/tests/main.rs b/crates/wasi-tls-nativetls/tests/main.rs new file mode 100644 index 000000000000..d86fe08f6371 --- /dev/null +++ b/crates/wasi-tls-nativetls/tests/main.rs @@ -0,0 +1,72 @@ +use anyhow::{Result, anyhow}; +use wasmtime::{ + Store, + component::{Component, Linker, ResourceTable}, +}; +use wasmtime_wasi::p2::{IoView, WasiCtx, WasiCtxBuilder, WasiView, bindings::Command}; +use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; + +struct Ctx { + table: ResourceTable, + wasi_ctx: WasiCtx, + wasi_tls_ctx: WasiTlsCtx, +} + +impl IoView for Ctx { + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } +} +impl WasiView for Ctx { + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.wasi_ctx + } +} + +async fn run_test(path: &str) -> Result<()> { + let provider = Box::new(wasmtime_wasi_tls_nativetls::NativeTlsProvider::default()); + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: WasiTlsCtxBuilder::new().provider(provider).build(), + }; + + let engine = test_programs_artifacts::engine(|config| { + config.async_support(true); + }); + let mut store = Store::new(&engine, ctx); + let component = Component::from_file(&engine, path)?; + + let mut linker = Linker::new(&engine); + wasmtime_wasi::p2::add_to_linker_async(&mut linker)?; + let mut opts = LinkOptions::default(); + opts.tls(true); + wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { + WasiTls::new(&h.wasi_tls_ctx, &mut h.table) + })?; + + let command = Command::instantiate_async(&mut store, &component, &linker).await?; + command + .wasi_cli_run() + .call_run(&mut store) + .await? + .map_err(|()| anyhow!("command returned with failing exit status")) +} + +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; + }; +} + +test_programs_artifacts::foreach_tls!(assert_test_exists); + +#[tokio::test(flavor = "multi_thread")] +async fn tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +} diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index 7bc7a29c26dc..be715c5b6d1b 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -18,14 +18,15 @@ tokio = { workspace = true, features = [ "net", "rt-multi-thread", "time", + "io-util", ] } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } + tokio-rustls = { workspace = true } rustls = { workspace = true } webpki-roots = { workspace = true } - [dev-dependencies] test-programs-artifacts = { workspace = true } wasmtime-wasi = { workspace = true } diff --git a/crates/wasi-tls/src/bindings.rs b/crates/wasi-tls/src/bindings.rs new file mode 100644 index 000000000000..355034ee512b --- /dev/null +++ b/crates/wasi-tls/src/bindings.rs @@ -0,0 +1,21 @@ +//! Auto-generated bindings. + +#[expect(missing_docs, reason = "bindgen-generated code")] +mod generated { + wasmtime::component::bindgen!({ + path: "wit", + world: "wasi:tls/imports", + with: { + "wasi:io": wasmtime_wasi::p2::bindings::io, + "wasi:tls/types/client-connection": crate::HostClientConnection, + "wasi:tls/types/client-handshake": crate::HostClientHandshake, + "wasi:tls/types/future-client-streams": crate::HostFutureClientStreams, + }, + trappable_imports: true, + async: { + only_imports: [], + } + }); +} + +pub use generated::wasi::tls::*; diff --git a/crates/wasi-tls/src/host.rs b/crates/wasi-tls/src/host.rs new file mode 100644 index 000000000000..24b21ac14631 --- /dev/null +++ b/crates/wasi-tls/src/host.rs @@ -0,0 +1,156 @@ +use anyhow::Result; +use wasmtime::component::Resource; +use wasmtime_wasi::async_trait; +use wasmtime_wasi::p2::Pollable; +use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable, IoError}; + +use crate::{ + TlsStream, TlsTransport, WasiTls, bindings, + io::{ + AsyncReadStream, AsyncWriteStream, FutureOutput, WasiFuture, WasiStreamReader, + WasiStreamWriter, + }, +}; + +impl<'a> bindings::types::Host for WasiTls<'a> {} + +/// Represents the ClientHandshake which will be used to configure the handshake +pub struct HostClientHandshake { + server_name: String, + transport: Box, +} + +impl<'a> bindings::types::HostClientHandshake for WasiTls<'a> { + fn new( + &mut self, + server_name: String, + input: Resource, + output: Resource, + ) -> wasmtime::Result> { + let input = self.table.delete(input)?; + let output = self.table.delete(output)?; + + let reader = WasiStreamReader::new(input); + let writer = WasiStreamWriter::new(output); + let transport = tokio::io::join(reader, writer); + + Ok(self.table.push(HostClientHandshake { + server_name, + transport: Box::new(transport) as Box, + })?) + } + + fn finish( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + let handshake = self.table.delete(this)?; + + let connect = self + .ctx + .provider + .connect(handshake.server_name, handshake.transport); + + let future = HostFutureClientStreams(WasiFuture::spawn(async move { + let tls_stream = connect.await?; + + let (rx, tx) = tokio::io::split(tls_stream); + let write_stream = AsyncWriteStream::new(tx); + let client = HostClientConnection(write_stream.clone()); + + let input = Box::new(AsyncReadStream::new(rx)) as DynInputStream; + let output = Box::new(write_stream) as DynOutputStream; + + Ok((client, input, output)) + })); + + Ok(self.table.push(future)?) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Future streams provides the tls streams after the handshake is completed +pub struct HostFutureClientStreams( + WasiFuture>, +); + +#[async_trait] +impl Pollable for HostFutureClientStreams { + async fn ready(&mut self) { + self.0.ready().await + } +} + +impl<'a> bindings::types::HostFutureClientStreams for WasiTls<'a> { + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + wasmtime_wasi::p2::subscribe(self.table, this) + } + + fn get( + &mut self, + this: Resource, + ) -> wasmtime::Result< + Option< + Result< + Result< + ( + Resource, + Resource, + Resource, + ), + Resource, + >, + (), + >, + >, + > { + let future = self.table.get_mut(&this)?; + + let result = match future.0.get() { + FutureOutput::Ready(Ok((client, input, output))) => { + let client = self.table.push(client)?; + let input = self.table.push_child(input, &client)?; + let output = self.table.push_child(output, &client)?; + + Some(Ok(Ok((client, input, output)))) + } + FutureOutput::Ready(Err(io_error)) => { + let io_error = self.table.push(io_error)?; + + Some(Ok(Err(io_error))) + } + FutureOutput::Consumed => Some(Err(())), + FutureOutput::Pending => None, + }; + + Ok(result) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Represents the client connection and used to shut down the tls stream +pub struct HostClientConnection( + crate::io::AsyncWriteStream>>, +); + +impl<'a> bindings::types::HostClientConnection for WasiTls<'a> { + fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.get_mut(&this)?.0.close() + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} diff --git a/crates/wasi-tls/src/io.rs b/crates/wasi-tls/src/io.rs new file mode 100644 index 000000000000..f2e408b7ef80 --- /dev/null +++ b/crates/wasi-tls/src/io.rs @@ -0,0 +1,408 @@ +//! Utility types for converting Rust & Tokio I/O types into WASI I/O types, +//! and vice versa. + +use anyhow::Result; +use bytes::Bytes; +use std::io; +use std::sync::Arc; +use std::task::{Poll, ready}; +use std::{future::Future, mem, pin::Pin}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::sync::Mutex; +use wasmtime_wasi::async_trait; +use wasmtime_wasi::p2::{ + DynInputStream, DynOutputStream, OutputStream, Pollable, StreamError, StreamResult, +}; +use wasmtime_wasi::runtime::AbortOnDropJoinHandle; + +enum FutureState { + Pending(Pin + Send>>), + Ready(T), + Consumed, +} + +pub(crate) enum FutureOutput { + Pending, + Ready(T), + Consumed, +} + +pub(crate) struct WasiFuture(FutureState); + +impl WasiFuture +where + T: Send + 'static, +{ + pub(crate) fn spawn(fut: F) -> Self + where + F: Future + Send + 'static, + { + Self(FutureState::Pending(Box::pin( + wasmtime_wasi::runtime::spawn(async move { fut.await }), + ))) + } + + pub(crate) fn get(&mut self) -> FutureOutput { + match &self.0 { + FutureState::Pending(_) => return FutureOutput::Pending, + FutureState::Consumed => return FutureOutput::Consumed, + FutureState::Ready(_) => (), + } + + let FutureState::Ready(value) = mem::replace(&mut self.0, FutureState::Consumed) else { + unreachable!() + }; + + FutureOutput::Ready(value) + } +} + +#[async_trait] +impl Pollable for WasiFuture +where + T: Send + 'static, +{ + async fn ready(&mut self) { + match &mut self.0 { + FutureState::Ready(_) | FutureState::Consumed => return, + FutureState::Pending(task) => self.0 = FutureState::Ready(task.as_mut().await), + } + } +} + +pub(crate) struct WasiStreamReader(FutureState); +impl WasiStreamReader { + pub(crate) fn new(stream: DynInputStream) -> Self { + Self(FutureState::Ready(stream)) + } +} +impl AsyncRead for WasiStreamReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + loop { + let stream = match &mut self.0 { + FutureState::Ready(stream) => stream, + FutureState::Pending(fut) => { + let stream = ready!(fut.as_mut().poll(cx)); + self.0 = FutureState::Ready(stream); + if let FutureState::Ready(stream) = &mut self.0 { + stream + } else { + unreachable!() + } + } + FutureState::Consumed => { + return Poll::Ready(Ok(())); + } + }; + match stream.read(buf.remaining()) { + Ok(bytes) if bytes.is_empty() => { + let FutureState::Ready(mut stream) = + std::mem::replace(&mut self.0, FutureState::Consumed) + else { + unreachable!() + }; + + self.0 = FutureState::Pending(Box::pin(async move { + stream.ready().await; + stream + })); + } + Ok(bytes) => { + buf.put_slice(&bytes); + + return Poll::Ready(Ok(())); + } + Err(StreamError::Closed) => { + self.0 = FutureState::Consumed; + return Poll::Ready(Ok(())); + } + Err(e) => { + self.0 = FutureState::Consumed; + return Poll::Ready(Err(std::io::Error::other(e))); + } + } + } + } +} + +pub(crate) struct WasiStreamWriter(FutureState); +impl WasiStreamWriter { + pub(crate) fn new(stream: DynOutputStream) -> Self { + Self(FutureState::Ready(stream)) + } +} +impl AsyncWrite for WasiStreamWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + loop { + match &mut self.as_mut().0 { + FutureState::Consumed => unreachable!(), + FutureState::Pending(future) => { + let value = ready!(future.as_mut().poll(cx)); + self.as_mut().0 = FutureState::Ready(value); + } + FutureState::Ready(output) => { + match output.check_write() { + Ok(0) => { + let FutureState::Ready(mut output) = + mem::replace(&mut self.as_mut().0, FutureState::Consumed) + else { + unreachable!() + }; + self.as_mut().0 = FutureState::Pending(Box::pin(async move { + output.ready().await; + output + })); + } + Ok(count) => { + let count = count.min(buf.len()); + return match output.write(Bytes::copy_from_slice(&buf[..count])) { + Ok(()) => Poll::Ready(Ok(count)), + Err(StreamError::Closed) => Poll::Ready(Ok(0)), + Err(e) => Poll::Ready(Err(std::io::Error::other(e))), + }; + } + Err(StreamError::Closed) => { + // Our current version of tokio-rustls does not handle returning `Ok(0)` well. + // See: https://github.com/rustls/tokio-rustls/issues/92 + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), + }; + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_write(cx, &[]).map(|v| v.map(drop)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_flush(cx) + } +} + +pub(crate) use wasmtime_wasi::p2::pipe::AsyncReadStream; + +pub(crate) struct AsyncWriteStream(Arc>>); + +impl AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + pub(crate) fn new(io: IO) -> Self { + AsyncWriteStream(Arc::new(Mutex::new(WriteState::new(io)))) + } + + pub(crate) fn close(&mut self) -> wasmtime::Result<()> { + self.try_lock()?.close(); + Ok(()) + } + + async fn lock(&self) -> tokio::sync::MutexGuard<'_, WriteState> { + self.0.lock().await + } + + fn try_lock(&self) -> Result>, StreamError> { + self.0 + .try_lock() + .map_err(|_| StreamError::trap("concurrent access to resource not supported")) + } +} +impl Clone for AsyncWriteStream { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +#[async_trait] +impl OutputStream for AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> { + self.try_lock()?.write(bytes) + } + + fn flush(&mut self) -> StreamResult<()> { + self.try_lock()?.flush() + } + + fn check_write(&mut self) -> StreamResult { + self.try_lock()?.check_write() + } + + async fn cancel(&mut self) { + self.lock().await.cancel().await + } +} + +#[async_trait] +impl Pollable for AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + async fn ready(&mut self) { + self.lock().await.ready().await + } +} + +enum WriteState { + Ready(IO), + Writing(AbortOnDropJoinHandle>), + Flushing(AbortOnDropJoinHandle>), + Closing(AbortOnDropJoinHandle>), + Closed, + Error(io::Error), +} +const READY_SIZE: usize = 1024 * 1024 * 1024; + +impl WriteState +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + fn new(stream: IO) -> Self { + Self::Ready(stream) + } + + fn write(&mut self, mut bytes: bytes::Bytes) -> StreamResult<()> { + let WriteState::Ready(_) = self else { + return Err(StreamError::Trap(anyhow::anyhow!( + "unpermitted: must call check_write first" + ))); + }; + + if bytes.is_empty() { + return Ok(()); + } + + let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + + *self = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move { + while !bytes.is_empty() { + let n = stream.write(&bytes).await?; + let _ = bytes.split_to(n); + } + + Ok(stream) + })); + + Ok(()) + } + + fn flush(&mut self) -> StreamResult<()> { + match self { + // Immediately flush: + WriteState::Ready(_) => { + let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed) + else { + unreachable!() + }; + *self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move { + stream.flush().await?; + Ok(stream) + })); + } + + // Schedule the flush after the current write has finished: + WriteState::Writing(_) => { + let WriteState::Writing(write) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + *self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move { + let mut stream = write.await?; + stream.flush().await?; + Ok(stream) + })); + } + + WriteState::Flushing(_) | WriteState::Closing(_) | WriteState::Error(_) => {} + WriteState::Closed => return Err(StreamError::Closed), + } + + Ok(()) + } + + fn check_write(&mut self) -> StreamResult { + match self { + WriteState::Ready(_) => Ok(READY_SIZE), + WriteState::Writing(_) => Ok(0), + WriteState::Flushing(_) => Ok(0), + WriteState::Closing(_) => Ok(0), + WriteState::Closed => Err(StreamError::Closed), + WriteState::Error(_) => { + let WriteState::Error(e) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + + Err(StreamError::LastOperationFailed(e.into())) + } + } + } + + fn close(&mut self) { + match std::mem::replace(self, WriteState::Closed) { + // No write in progress, immediately shut down: + WriteState::Ready(mut stream) => { + *self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + stream.shutdown().await + })); + } + + // Schedule the shutdown after the current operation has finished: + WriteState::Writing(op) | WriteState::Flushing(op) => { + *self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + let mut stream = op.await?; + stream.shutdown().await + })); + } + + WriteState::Closing(t) => { + *self = WriteState::Closing(t); + } + WriteState::Closed | WriteState::Error(_) => {} + } + } + + async fn cancel(&mut self) { + match std::mem::replace(self, WriteState::Closed) { + WriteState::Writing(task) | WriteState::Flushing(task) => _ = task.cancel().await, + WriteState::Closing(task) => _ = task.cancel().await, + _ => {} + } + } + + async fn ready(&mut self) { + match self { + WriteState::Writing(task) | WriteState::Flushing(task) => { + *self = match task.await { + Ok(s) => WriteState::Ready(s), + Err(e) => WriteState::Error(e), + } + } + WriteState::Closing(task) => { + *self = match task.await { + Ok(()) => WriteState::Closed, + Err(e) => WriteState::Error(e), + } + } + _ => {} + } + } +} diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 2e8733c9c3ff..28c96d8a4626 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -13,11 +13,12 @@ //! component::{Linker, ResourceTable}, //! Store, Engine, Result, Config //! }; -//! use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; +//! use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; //! //! struct Ctx { //! table: ResourceTable, //! wasi_ctx: WasiCtx, +//! wasi_tls_ctx: WasiTlsCtx, //! } //! //! impl IoView for Ctx { @@ -41,6 +42,10 @@ //! .inherit_network() //! .allow_ip_name_lookup(true) //! .build(), +//! wasi_tls_ctx: WasiTlsCtxBuilder::new() +//! // Optionally, configure a different TLS provider: +//! // .provider(Box::new(wasmtime_wasi_tls_nativetls::NativeTlsProvider::default())) +//! .build(), //! }; //! //! let mut config = Config::new(); @@ -56,7 +61,7 @@ //! let mut opts = LinkOptions::default(); //! opts.tls(true); //! wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { -//! WasiTlsCtx::new(&mut h.table) +//! WasiTls::new(&h.wasi_tls_ctx, &mut h.table) //! })?; //! //! // ... use `linker` to instantiate within `store` ... @@ -71,641 +76,101 @@ #![doc(test(attr(deny(warnings))))] #![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] -use anyhow::Result; -use bytes::Bytes; -use rustls::pki_types::ServerName; -use std::io; -use std::sync::Arc; -use std::task::{Poll, ready}; -use std::{future::Future, mem, pin::Pin, sync::LazyLock}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::Mutex; -use tokio_rustls::client::TlsStream; -use wasmtime::component::{HasData, Resource, ResourceTable}; -use wasmtime_wasi::async_trait; -use wasmtime_wasi::p2::bindings::io::{ - error::Error as HostIoError, - poll::Pollable as HostPollable, - streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream}, -}; -use wasmtime_wasi::p2::pipe::AsyncReadStream; -use wasmtime_wasi::p2::{OutputStream, Pollable, StreamError}; -use wasmtime_wasi::runtime::AbortOnDropJoinHandle; +use tokio::io::{AsyncRead, AsyncWrite}; +use wasmtime::component::{HasData, ResourceTable}; -mod gen_ { - wasmtime::component::bindgen!({ - path: "wit", - world: "wasi:tls/imports", - with: { - "wasi:io": wasmtime_wasi::p2::bindings::io, - "wasi:tls/types/client-connection": super::ClientConnection, - "wasi:tls/types/client-handshake": super::ClientHandShake, - "wasi:tls/types/future-client-streams": super::FutureClientStreams, - }, - trappable_imports: true, - async: { - only_imports: [], - } - }); -} -pub use gen_::wasi::tls::types::LinkOptions; -use gen_::wasi::tls::{self as generated}; +pub mod bindings; +mod host; +mod io; +mod rustls; -fn default_client_config() -> Arc { - static CONFIG: LazyLock> = LazyLock::new(|| { - let roots = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - let config = rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_no_client_auth(); - Arc::new(config) - }); - Arc::clone(&CONFIG) -} +pub use bindings::types::LinkOptions; +pub use host::{HostClientConnection, HostClientHandshake, HostFutureClientStreams}; +pub use rustls::RustlsProvider; -/// Wasi TLS context needed fro internal `wasi-tls`` state -pub struct WasiTlsCtx<'a> { +/// Capture the state necessary for use in the `wasi-tls` API implementation. +pub struct WasiTls<'a> { + ctx: &'a WasiTlsCtx, table: &'a mut ResourceTable, } -impl<'a> WasiTlsCtx<'a> { +impl<'a> WasiTls<'a> { /// Create a new Wasi TLS context - pub fn new(table: &'a mut ResourceTable) -> Self { - Self { table } + pub fn new(ctx: &'a WasiTlsCtx, table: &'a mut ResourceTable) -> Self { + Self { ctx, table } } } -impl<'a> generated::types::Host for WasiTlsCtx<'a> {} - /// Add the `wasi-tls` world's types to a [`wasmtime::component::Linker`]. pub fn add_to_linker( l: &mut wasmtime::component::Linker, opts: &mut LinkOptions, - f: fn(&mut T) -> WasiTlsCtx<'_>, -) -> Result<()> { - generated::types::add_to_linker::<_, WasiTls>(l, &opts, f)?; + f: fn(&mut T) -> WasiTls<'_>, +) -> anyhow::Result<()> { + bindings::types::add_to_linker::<_, HasWasiTls>(l, &opts, f)?; Ok(()) } -struct WasiTls; - -impl HasData for WasiTls { - type Data<'a> = WasiTlsCtx<'a>; -} - -enum TlsError { - /// The component should trap. Under normal circumstances, this only occurs - /// when the underlying transport stream returns [`StreamError::Trap`]. - Trap(anyhow::Error), - - /// A failure indicated by the underlying transport stream as - /// [`StreamError::LastOperationFailed`]. - Io(wasmtime_wasi::p2::IoError), - - /// A TLS protocol error occurred. - Tls(rustls::Error), -} - -impl TlsError { - /// Create a [`TlsError::Tls`] error from a simple message. - fn msg(msg: &str) -> Self { - // (Ab)using rustls' error type to synthesize our own TLS errors: - Self::Tls(rustls::Error::General(msg.to_string())) - } -} - -impl From for TlsError { - fn from(error: io::Error) -> Self { - // Report unexpected EOFs as an error to prevent truncation attacks. - // See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read - if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() { - return Self::msg("underlying transport closed abruptly"); - } - - // Errors from underlying transport. - // These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below. - let error = match error.downcast::() { - Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e), - Ok(StreamError::Trap(e)) => return Self::Trap(e), - Ok(StreamError::Closed) => unreachable!( - "our wasi-to-tokio stream transformer should have translated this to a 0-sized read" - ), - Err(e) => e, - }; - - // Errors from `rustls`. - // These have been wrapped inside `io::Error`s by `tokio-rustls`. - let error = match error.downcast::() { - Ok(e) => return Self::Tls(e), - Err(e) => e, - }; - - // All errors should have been handled by the clauses above. - Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error")) - } -} - -/// Represents the ClientHandshake which will be used to configure the handshake -pub struct ClientHandShake { - server_name: String, - streams: WasiStreams, +struct HasWasiTls; +impl HasData for HasWasiTls { + type Data<'a> = WasiTls<'a>; } -impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { - fn new( - &mut self, - server_name: String, - input: Resource, - output: Resource, - ) -> wasmtime::Result> { - let input = self.table.delete(input)?; - let output = self.table.delete(output)?; - Ok(self.table.push(ClientHandShake { - server_name, - streams: WasiStreams { - input: StreamState::Ready(input), - output: StreamState::Ready(output), - }, - })?) - } - - fn finish( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result> { - let handshake = self.table.delete(this)?; - let server_name = handshake.server_name; - let streams = handshake.streams; - - Ok(self - .table - .push(FutureStreams(StreamState::Pending(Box::pin(async move { - let domain = ServerName::try_from(server_name) - .map_err(|_| TlsError::msg("invalid server name"))?; - - let stream = tokio_rustls::TlsConnector::from(default_client_config()) - .connect(domain, streams) - .await?; - Ok(stream) - }))))?) - } - - fn drop( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } +/// Builder-style structure used to create a [`WasiTlsCtx`]. +pub struct WasiTlsCtxBuilder { + provider: Box, } -/// Future streams provides the tls streams after the handshake is completed -pub struct FutureStreams(StreamState>); - -/// Library specific version of TLS connection after the handshake is completed. -/// This alias allows it to use with wit-bindgen component generator which won't take generic types -pub type FutureClientStreams = FutureStreams>; - -#[async_trait] -impl Pollable for FutureStreams { - async fn ready(&mut self) { - match &mut self.0 { - StreamState::Ready(_) | StreamState::Closed => return, - StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await), - } - } -} - -impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> { - fn subscribe( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result> { - wasmtime_wasi::p2::subscribe(self.table, this) - } - - fn get( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result< - Option< - Result< - Result< - ( - Resource, - Resource, - Resource, - ), - Resource, - >, - (), - >, - >, - > { - let this = &mut self.table.get_mut(&this)?.0; - match this { - StreamState::Pending(_) => return Ok(None), - StreamState::Closed => return Ok(Some(Err(()))), - StreamState::Ready(_) => (), - } - - let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else { - unreachable!() - }; - - let tls_stream = match result { - Ok(s) => s, - Err(TlsError::Trap(e)) => return Err(e), - Err(TlsError::Io(e)) => { - let error = self.table.push(e)?; - return Ok(Some(Ok(Err(error)))); - } - Err(TlsError::Tls(e)) => { - let error = self.table.push(wasmtime_wasi::p2::IoError::new(e))?; - return Ok(Some(Ok(Err(error)))); - } - }; - - let (rx, tx) = tokio::io::split(tls_stream); - let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx)); - let client = ClientConnection { - writer: write_stream.clone(), - }; - - let input = Box::new(AsyncReadStream::new(rx)) as BoxInputStream; - let output = Box::new(write_stream) as BoxOutputStream; - - let client = self.table.push(client)?; - let input = self.table.push_child(input, &client)?; - let output = self.table.push_child(output, &client)?; - - Ok(Some(Ok(Ok((client, input, output))))) - } - - fn drop( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } -} - -/// Represents the client connection and used to shut down the tls stream -pub struct ClientConnection { - writer: AsyncTlsWriteStream, -} - -impl<'a> generated::types::HostClientConnection for WasiTlsCtx<'a> { - fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { - self.table.get_mut(&this)?.writer.close() - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } -} - -enum StreamState { - Ready(T), - Pending(Pin + Send>>), - Closed, -} - -/// Wrapper around Input and Output wasi IO Stream that provides Async Read/Write -pub struct WasiStreams { - input: StreamState, - output: StreamState, -} - -impl AsyncWrite for WasiStreams { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - loop { - match &mut self.as_mut().output { - StreamState::Closed => unreachable!(), - StreamState::Pending(future) => { - let value = ready!(future.as_mut().poll(cx)); - self.as_mut().output = StreamState::Ready(value); - } - StreamState::Ready(output) => { - match output.check_write() { - Ok(0) => { - let StreamState::Ready(mut output) = - mem::replace(&mut self.as_mut().output, StreamState::Closed) - else { - unreachable!() - }; - self.as_mut().output = StreamState::Pending(Box::pin(async move { - output.ready().await; - output - })); - } - Ok(count) => { - let count = count.min(buf.len()); - return match output.write(Bytes::copy_from_slice(&buf[..count])) { - Ok(()) => Poll::Ready(Ok(count)), - Err(StreamError::Closed) => Poll::Ready(Ok(0)), - Err(e) => Poll::Ready(Err(std::io::Error::other(e))), - }; - } - Err(StreamError::Closed) => { - // Our current version of tokio-rustls does not handle returning `Ok(0)` well. - // See: https://github.com/rustls/tokio-rustls/issues/92 - return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); - } - Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), - }; - } - } - } - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_write(cx, &[]).map(|v| v.map(drop)) +impl WasiTlsCtxBuilder { + /// Creates a builder for a new context with default parameters set. + pub fn new() -> Self { + Default::default() } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_flush(cx) + /// Configure the TLS provider to use for this context. + /// + /// By default, this is set to the [`RustlsProvider`]. + pub fn provider(mut self, provider: Box) -> Self { + self.provider = provider; + self } -} - -impl AsyncRead for WasiStreams { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - loop { - let stream = match &mut self.input { - StreamState::Ready(stream) => stream, - StreamState::Pending(fut) => { - let stream = ready!(fut.as_mut().poll(cx)); - self.input = StreamState::Ready(stream); - if let StreamState::Ready(stream) = &mut self.input { - stream - } else { - unreachable!() - } - } - StreamState::Closed => { - return Poll::Ready(Ok(())); - } - }; - match stream.read(buf.remaining()) { - Ok(bytes) if bytes.is_empty() => { - let StreamState::Ready(mut stream) = - std::mem::replace(&mut self.input, StreamState::Closed) - else { - unreachable!() - }; - - self.input = StreamState::Pending(Box::pin(async move { - stream.ready().await; - stream - })); - } - Ok(bytes) => { - buf.put_slice(&bytes); - return Poll::Ready(Ok(())); - } - Err(StreamError::Closed) => { - self.input = StreamState::Closed; - return Poll::Ready(Ok(())); - } - Err(e) => { - self.input = StreamState::Closed; - return Poll::Ready(Err(std::io::Error::other(e))); - } - } + /// Uses the configured context so far to construct the final [`WasiTlsCtx`]. + pub fn build(self) -> WasiTlsCtx { + WasiTlsCtx { + provider: self.provider, } } } - -type TlsWriteHalf = tokio::io::WriteHalf>; - -struct TlsWriter { - state: WriteState, -} - -enum WriteState { - Ready(TlsWriteHalf), - Writing(AbortOnDropJoinHandle>), - Closing(AbortOnDropJoinHandle>), - Closed, - Error(io::Error), -} -const READY_SIZE: usize = 1024 * 1024 * 1024; - -impl TlsWriter { - fn new(stream: TlsWriteHalf) -> Self { +impl Default for WasiTlsCtxBuilder { + fn default() -> Self { Self { - state: WriteState::Ready(stream), - } - } - - fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> { - let WriteState::Ready(_) = self.state else { - return Err(StreamError::Trap(anyhow::anyhow!( - "unpermitted: must call check_write first" - ))); - }; - - if bytes.is_empty() { - return Ok(()); - } - - let WriteState::Ready(mut stream) = std::mem::replace(&mut self.state, WriteState::Closed) - else { - unreachable!() - }; - - self.state = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move { - while !bytes.is_empty() { - let n = stream.write(&bytes).await?; - let _ = bytes.split_to(n); - } - - Ok(stream) - })); - - Ok(()) - } - - fn flush(&mut self) -> Result<(), StreamError> { - // `flush` is a no-op here, as we're not managing any internal buffer. - match self.state { - WriteState::Ready(_) - | WriteState::Writing(_) - | WriteState::Closing(_) - | WriteState::Error(_) => Ok(()), - WriteState::Closed => Err(StreamError::Closed), - } - } - - fn check_write(&mut self) -> Result { - match &mut self.state { - WriteState::Ready(_) => Ok(READY_SIZE), - WriteState::Writing(_) => Ok(0), - WriteState::Closing(_) => Ok(0), - WriteState::Closed => Err(StreamError::Closed), - WriteState::Error(_) => { - let WriteState::Error(e) = std::mem::replace(&mut self.state, WriteState::Closed) - else { - unreachable!() - }; - - Err(StreamError::LastOperationFailed(e.into())) - } - } - } - - fn close(&mut self) { - match std::mem::replace(&mut self.state, WriteState::Closed) { - // No write in progress, immediately shut down: - WriteState::Ready(mut stream) => { - self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { - stream.shutdown().await - })); - } - - // Schedule the shutdown after the current write has finished: - WriteState::Writing(write) => { - self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { - let mut stream = write.await?; - stream.shutdown().await - })); - } - - WriteState::Closing(t) => { - self.state = WriteState::Closing(t); - } - WriteState::Closed | WriteState::Error(_) => {} - } - } - - async fn cancel(&mut self) { - match std::mem::replace(&mut self.state, WriteState::Closed) { - WriteState::Writing(task) => _ = task.cancel().await, - WriteState::Closing(task) => _ = task.cancel().await, - _ => {} - } - } - - async fn ready(&mut self) { - match &mut self.state { - WriteState::Writing(task) => { - self.state = match task.await { - Ok(s) => WriteState::Ready(s), - Err(e) => WriteState::Error(e), - } - } - WriteState::Closing(task) => { - self.state = match task.await { - Ok(()) => WriteState::Closed, - Err(e) => WriteState::Error(e), - } - } - _ => {} + provider: Box::new(RustlsProvider::default()), } } } -#[derive(Clone)] -struct AsyncTlsWriteStream(Arc>); - -impl AsyncTlsWriteStream { - fn new(writer: TlsWriter) -> Self { - AsyncTlsWriteStream(Arc::new(Mutex::new(writer))) - } - - fn close(&mut self) -> wasmtime::Result<()> { - try_lock_for_stream(&self.0)?.close(); - Ok(()) - } +/// Wasi TLS context needed for internal `wasi-tls` state. +pub struct WasiTlsCtx { + pub(crate) provider: Box, } -#[async_trait] -impl OutputStream for AsyncTlsWriteStream { - fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> { - try_lock_for_stream(&self.0)?.write(bytes) - } - - fn flush(&mut self) -> Result<(), StreamError> { - try_lock_for_stream(&self.0)?.flush() - } - - fn check_write(&mut self) -> Result { - try_lock_for_stream(&self.0)?.check_write() - } +/// The data stream that carries the encrypted TLS data. +/// Typically this is a TCP stream. +pub trait TlsTransport: AsyncRead + AsyncWrite + Send + Unpin + 'static {} +impl TlsTransport for T {} - async fn cancel(&mut self) { - self.0.lock().await.cancel().await - } -} +/// A TLS connection. +pub trait TlsStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} -#[async_trait] -impl Pollable for AsyncTlsWriteStream { - async fn ready(&mut self) { - self.0.lock().await.ready().await - } -} - -fn try_lock_for_stream( - mutex: &Mutex, -) -> Result, StreamError> { - mutex - .try_lock() - .map_err(|_| StreamError::trap("concurrent access to resource not supported")) +/// A TLS implementation. +pub trait TlsProvider: Send + Sync + 'static { + /// Set up a client TLS connection using the provided `server_name` and `transport`. + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>>; } -#[cfg(test)] -mod tests { - use super::*; - use std::task::Waker; - use tokio::sync::oneshot; - - #[tokio::test] - async fn test_future_client_streams_ready_can_be_canceled() { - let (tx1, rx1) = oneshot::channel::<()>(); - - let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move { - rx1.await - .map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled"))) - }))); - - let mut fut = future_streams.ready(); - - let mut cx = std::task::Context::from_waker(Waker::noop()); - assert!(fut.as_mut().poll(&mut cx).is_pending()); - - //cancel the readiness check - drop(fut); - - match future_streams.0 { - StreamState::Closed => panic!("First future should be in Pending/ready state"), - _ => (), - } - - // make it ready and wait for it to progress - tx1.send(()).unwrap(); - future_streams.ready().await; - - match future_streams.0 { - StreamState::Ready(Ok(())) => (), - _ => panic!("First future should be in Ready(Err) state"), - } - } -} +pub(crate) type BoxFuture = std::pin::Pin + Send>>; diff --git a/crates/wasi-tls/src/rustls.rs b/crates/wasi-tls/src/rustls.rs new file mode 100644 index 000000000000..0d8fcd2cf22a --- /dev/null +++ b/crates/wasi-tls/src/rustls.rs @@ -0,0 +1,51 @@ +//! The `rustls` provider. + +use rustls::pki_types::ServerName; +use std::io; +use std::sync::{Arc, LazyLock}; + +use crate::{BoxFuture, TlsProvider, TlsStream, TlsTransport}; + +impl crate::TlsStream for tokio_rustls::client::TlsStream> {} + +/// The `rustls` provider. +pub struct RustlsProvider { + client_config: Arc, +} + +impl TlsProvider for RustlsProvider { + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>> { + let client_config = Arc::clone(&self.client_config); + Box::pin(async move { + let domain = ServerName::try_from(server_name) + .map_err(|_| io::Error::other("invalid server name"))?; + + let stream = tokio_rustls::TlsConnector::from(client_config) + .connect(domain, transport) + .await?; + Ok(Box::new(stream) as Box) + }) + } +} + +impl Default for RustlsProvider { + fn default() -> Self { + static CONFIG: LazyLock> = LazyLock::new(|| { + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + Arc::new(config) + }); + + Self { + client_config: Arc::clone(&CONFIG), + } + } +} diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/main.rs index 21fedacb6492..3105cee3a517 100644 --- a/crates/wasi-tls/tests/main.rs +++ b/crates/wasi-tls/tests/main.rs @@ -1,15 +1,15 @@ use anyhow::{Result, anyhow}; -use test_programs_artifacts::{TLS_SAMPLE_APPLICATION_COMPONENT, foreach_tls}; use wasmtime::{ Store, component::{Component, Linker, ResourceTable}, }; use wasmtime_wasi::p2::{IoView, WasiCtx, WasiCtxBuilder, WasiView, bindings::Command}; -use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; +use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; struct Ctx { table: ResourceTable, wasi_ctx: WasiCtx, + wasi_tls_ctx: WasiTlsCtx, } impl IoView for Ctx { @@ -23,7 +23,17 @@ impl WasiView for Ctx { } } -async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { +async fn run_test(path: &str) -> Result<()> { + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: WasiTlsCtxBuilder::new().build(), + }; + let engine = test_programs_artifacts::engine(|config| { config.async_support(true); }); @@ -35,7 +45,7 @@ async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { let mut opts = LinkOptions::default(); opts.tls(true); wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { - WasiTlsCtx::new(&mut h.table) + WasiTls::new(&h.wasi_tls_ctx, &mut h.table) })?; let command = Command::instantiate_async(&mut store, &component, &linker).await?; @@ -53,20 +63,9 @@ macro_rules! assert_test_exists { }; } -foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] async fn tls_sample_application() -> Result<()> { - run_wasi( - TLS_SAMPLE_APPLICATION_COMPONENT, - Ctx { - table: ResourceTable::new(), - wasi_ctx: WasiCtxBuilder::new() - .inherit_stderr() - .inherit_network() - .allow_ip_name_lookup(true) - .build(), - }, - ) - .await + run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/scripts/publish.rs b/scripts/publish.rs index f34c563cfab4..f3ea0247d87d 100644 --- a/scripts/publish.rs +++ b/scripts/publish.rs @@ -80,6 +80,7 @@ const CRATES_TO_PUBLISH: &[&str] = &[ "wasmtime-wasi-keyvalue", "wasmtime-wasi-threads", "wasmtime-wasi-tls", + "wasmtime-wasi-tls-nativetls", "wasmtime-wast", "wasmtime-internal-c-api-macros", "wasmtime-c-api-impl", @@ -99,6 +100,7 @@ const PUBLIC_CRATES: &[&str] = &[ "wasmtime-wasi-io", "wasmtime-wasi", "wasmtime-wasi-tls", + "wasmtime-wasi-tls-nativetls", "wasmtime-wasi-http", "wasmtime-wasi-nn", "wasmtime-wasi-config", diff --git a/src/commands/run.rs b/src/commands/run.rs index d71a685b72a2..e4e13d23b7ad 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -33,7 +33,7 @@ use wasmtime_wasi_http::{ use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder}; #[cfg(feature = "wasi-tls")] -use wasmtime_wasi_tls::WasiTlsCtx; +use wasmtime_wasi_tls::{WasiTls, WasiTlsCtx}; fn parse_preloads(s: &str) -> Result<(String, PathBuf)> { let parts: Vec<&str> = s.splitn(2, '=').collect(); @@ -993,8 +993,14 @@ impl RunCommand { h.preview2_ctx.as_mut().expect("wasip2 is not configured"); let preview2_ctx = Arc::get_mut(preview2_ctx).unwrap().get_mut().unwrap(); - WasiTlsCtx::new(preview2_ctx.table()) + WasiTls::new( + Arc::get_mut(h.wasi_tls.as_mut().unwrap()).unwrap(), + preview2_ctx.table(), + ) })?; + + let ctx = wasmtime_wasi_tls::WasiTlsCtxBuilder::new().build(); + store.data_mut().wasi_tls = Some(Arc::new(ctx)); } } } @@ -1105,6 +1111,8 @@ struct Host { wasi_config: Option>, #[cfg(feature = "wasi-keyvalue")] wasi_keyvalue: Option>, + #[cfg(feature = "wasi-tls")] + wasi_tls: Option>, } impl Host { diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index c172179e5b93..ae8f8f94fcd8 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -622,6 +622,13 @@ user-id = 6743 user-login = "epage" user-name = "Ed Page" +[[publisher.core-foundation]] +version = "0.9.3" +when = "2022-02-07" +user-id = 5946 +user-login = "jrmuizel" +user-name = "Jeff Muizelaar" + [[publisher.core-foundation-sys]] version = "0.8.4" when = "2023-04-03" @@ -931,6 +938,13 @@ user-id = 189 user-login = "BurntSushi" user-name = "Andrew Gallant" +[[publisher.openssl-probe]] +version = "0.1.6" +when = "2025-01-23" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.prettyplease]] version = "0.2.31" when = "2025-03-13"