diff --git a/Cargo.lock b/Cargo.lock index f4b3a4aa..64c22e20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3593,6 +3593,12 @@ dependencies = [ "url", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "target-lexicon" version = "0.13.5" @@ -4295,10 +4301,13 @@ dependencies = [ name = "ts_elixir" version = "0.2.0" dependencies = [ + "futures-util", "rustler", "tailscale", + "tap", "tokio", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/ts_elixir/config/config.exs b/ts_elixir/config/config.exs new file mode 100644 index 00000000..a43c1d18 --- /dev/null +++ b/ts_elixir/config/config.exs @@ -0,0 +1,7 @@ +import Config + +config :tailscale, + testing_nifs: false, + profile: :debug + +import_config "#{config_env()}.exs" diff --git a/ts_elixir/config/dev.exs b/ts_elixir/config/dev.exs new file mode 100644 index 00000000..03b03c6f --- /dev/null +++ b/ts_elixir/config/dev.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + testing_nifs: true diff --git a/ts_elixir/config/prod.exs b/ts_elixir/config/prod.exs new file mode 100644 index 00000000..e58378b4 --- /dev/null +++ b/ts_elixir/config/prod.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + profile: :release diff --git a/ts_elixir/config/test.exs b/ts_elixir/config/test.exs new file mode 100644 index 00000000..03b03c6f --- /dev/null +++ b/ts_elixir/config/test.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + testing_nifs: true diff --git a/ts_elixir/lib/tailscale.ex b/ts_elixir/lib/tailscale.ex index 3582e87a..71dcb47f 100644 --- a/ts_elixir/lib/tailscale.ex +++ b/ts_elixir/lib/tailscale.ex @@ -1,4 +1,6 @@ defmodule Tailscale do + require Tailscale.Util + @moduledoc """ Elixir bindings for the Tailscale Rust client. @@ -66,8 +68,8 @@ defmodule Tailscale do See `t:options/0` for details on available options. """ - def connect(key_file_path, options) when is_binary(key_file_path) do - case Tailscale.Native.load_key_file(key_file_path) do + def connect(key_file_path, options) when is_binary(key_file_path) and is_list(options) do + case Tailscale.Util.await(Tailscale.Native.load_key_file(key_file_path)) do {:ok, keys} -> Keyword.put(options, :keys, keys) |> connect() @@ -86,8 +88,10 @@ defmodule Tailscale do """ def connect(options \\ []) - def connect(options) when is_list(options), - do: :proplists.to_map(options) |> Tailscale.Native.connect() + def connect(options) when is_list(options) do + options = :proplists.to_map(options) + Tailscale.Util.await(Tailscale.Native.connect(options)) + end def connect(key_file_path) when is_binary(key_file_path), do: connect(key_file_path, []) @@ -97,7 +101,7 @@ defmodule Tailscale do Blocks until the address is available. """ - def ipv4_addr(dev), do: Tailscale.Native.ipv4_addr(dev) + def ipv4_addr(dev), do: Tailscale.Util.await(Tailscale.Native.ipv4_addr(dev)) @spec ipv6_addr(t()) :: {:ok, :inet.ip6_address()} | {:error, any()} @doc """ @@ -108,13 +112,13 @@ defmodule Tailscale do Note that this address is in `t::inet.ip6_address/0` format (16-bit segments), which may be difficult to read. See `:inet.ntoa/1` to format to a string. """ - def ipv6_addr(dev), do: Tailscale.Native.ipv6_addr(dev) + def ipv6_addr(dev), do: Tailscale.Util.await(Tailscale.Native.ipv6_addr(dev)) @spec self_node(t()) :: {:ok, Tailscale.NodeInfo.t()} | {:error, any()} @doc """ Get this node's `m:Tailscale.NodeInfo`. """ - defdelegate self_node(dev), to: Tailscale.Native + def self_node(dev), do: Tailscale.Util.await(Tailscale.Native.self_node(dev)) @spec peer_by_name(t(), String.t()) :: {:ok, Tailscale.NodeInfo.t() | nil} | {:error, any()} @doc """ @@ -123,7 +127,8 @@ defmodule Tailscale do Returns `{:ok, nil}` if there was no such peer, and `{:error, reason}` if the lookup encountered an error. """ - def peer_by_name(dev, name), do: Tailscale.Native.peer_by_name(dev, name) + def peer_by_name(dev, name), + do: Tailscale.Util.await(Tailscale.Native.peer_by_name(dev, name)) @spec peer_by_tailnet_ip(t(), Tailscale.ip_addr()) :: {:ok, Tailscale.NodeInfo.t() | nil} | {:error, any()} @@ -132,12 +137,14 @@ defmodule Tailscale do Returns `{:ok, nil}` if there was no such peer. `:error` if the lookup encountered an error. """ - defdelegate peer_by_tailnet_ip(dev, ip), to: Tailscale.Native + def peer_by_tailnet_ip(dev, ip), + do: Tailscale.Util.await(Tailscale.Native.peer_by_tailnet_ip(dev, ip)) @spec peers_with_route(t(), Tailscale.ip_addr()) :: {:ok, [Tailscale.NodeInfo.t()]} | {:error, any()} @doc """ Retrieve the most narrow set of peers that accept packets for the specified IP. """ - defdelegate peers_with_route(dev, ip), to: Tailscale.Native + def peers_with_route(dev, ip), + do: Tailscale.Util.await(Tailscale.Native.peers_with_route(dev, ip)) end diff --git a/ts_elixir/lib/tailscale/native.ex b/ts_elixir/lib/tailscale/native.ex index abd7dfdd..eed11358 100644 --- a/ts_elixir/lib/tailscale/native.ex +++ b/ts_elixir/lib/tailscale/native.ex @@ -1,9 +1,20 @@ defmodule Tailscale.Native do + @moduledoc false + + @testing_nifs Application.compile_env!(:tailscale, :testing_nifs) + @profile Application.compile_env!(:tailscale, :profile) + + @features (if @testing_nifs do + ["testing-nifs"] + else + [] + end) + use Rustler, otp_app: :tailscale, - crate: :ts_elixir - - @moduledoc false + crate: :ts_elixir, + mode: @profile, + features: @features # The Elixir side of the Rustler bindings to `tailscale-rs`. # @@ -32,6 +43,31 @@ defmodule Tailscale.Native do """ @opaque tcp_stream :: reference() + @typedoc """ + NIFs provided here may have asynchronous effects that would typically block and require the use of + the DirtyIO scheduler. This is undesirable as we may have a large number of concurrent calls into + the NIFs, which could exhaust the DirtyIO thread pool. Instead, we use message passing on the Rust + side to send replies back into the BEAM. Functions that use this model return `async_reply` + without blocking. The `:async` case means the reply will be sent asynchronously using a message of + the format `{:tailscale, REF, PAYLOAD}`, where `REF` is the reference associated with the `:async` + response, guaranteed unique per call. + + The `:error` response means that an error was encountered before dispatching the asynchronous + call. + + The `:nif_panic` response means that the NIF panicked during execution; the second parameter is + the reason for the panic (if given). + + `{:raise, TERM}` means `TERM` should be raised as an exception. + + `m:Tailscale.Util` has helpers for decoding messages of this form. + """ + @type async_reply() :: + {:async, reference()} + | {:error, any()} + | {:nif_panic, String.t() | {}} + | {:raise, any()} + defp err, do: :erlang.nif_error(:nif_not_loaded) @doc """ @@ -39,7 +75,7 @@ defmodule Tailscale.Native do See `t:Tailscale.options/0` for details on what options are supported. """ - @spec connect(%{}) :: {:ok, device()} | {:error, any()} + @spec connect(%{}) :: async_reply() def connect(_opts), do: err() @doc """ @@ -51,7 +87,7 @@ defmodule Tailscale.Native do - `port`: the port to which the socket should bind. """ @spec udp_bind(device(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: - {:ok, udp_socket()} | {:error, any()} + async_reply() def udp_bind(_dev, _addr, _port), do: err() @doc """ @@ -65,14 +101,14 @@ defmodule Tailscale.Native do - `msg`: the packet to send. """ @spec udp_send(udp_socket(), Tailscale.ip_addr(), :inet.port_number(), binary()) :: - :ok | {:error, any()} + async_reply() def udp_send(_sock, _ip, _port, _msg), do: err() @doc """ Receive an incoming UDP packet on the given socket. """ @spec udp_recv(udp_socket()) :: - {:ok, :inet.ip_address(), :inet.port_number(), binary()} | {:error, any()} + async_reply() def udp_recv(_sock), do: err() @doc """ @@ -92,7 +128,7 @@ defmodule Tailscale.Native do Start a TCP listener on the given device, address, and port. """ @spec tcp_listen(device(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: - {:ok, tcp_listener()} | {:error, any()} + async_reply() def tcp_listen(_dev, _addr, _port), do: err() @doc """ @@ -105,13 +141,13 @@ defmodule Tailscale.Native do Connect to the given TCP endpoint using the given device. """ @spec tcp_connect(device(), Tailscale.ip_addr(), :inet.port_number()) :: - {:ok, tcp_stream()} | {:error, any()} + async_reply() def tcp_connect(_dev, _addr, _port), do: err() @doc """ Accept an incoming TCP connection. Blocks until one is available. """ - @spec tcp_accept(tcp_listener()) :: {:ok, tcp_stream()} | {:error, any()} + @spec tcp_accept(tcp_listener()) :: async_reply() def tcp_accept(_listener), do: err() @doc """ @@ -120,13 +156,13 @@ defmodule Tailscale.Native do Returns the number of bytes actually written to the remote. """ - @spec tcp_send(tcp_stream(), binary()) :: {:ok, integer()} | {:error, any()} + @spec tcp_send(tcp_stream(), binary()) :: async_reply() def tcp_send(_stream, _msg), do: err() @doc """ Receive incoming data from the tcp socket, blocking until at least one byte can be received. """ - @spec tcp_recv(tcp_stream()) :: {:ok, binary()} | {:error, any()} + @spec tcp_recv(tcp_stream()) :: async_reply() def tcp_recv(_stream), do: err() @doc """ @@ -146,7 +182,7 @@ defmodule Tailscale.Native do Blocks until the device is connected and gets its address from control. """ - @spec ipv4_addr(device()) :: {:ok, :inet.ip4_address()} | {:error, any()} + @spec ipv4_addr(device()) :: async_reply() def ipv4_addr(_dev), do: err() @doc """ @@ -154,36 +190,68 @@ defmodule Tailscale.Native do Blocks until the device is connected and gets its address from control. """ - @spec ipv6_addr(device()) :: {:ok, :inet.ip6_address()} | {:error, any()} + @spec ipv6_addr(device()) :: async_reply() def ipv6_addr(_dev), do: err() @doc """ Retrieve a peer by name. """ - @spec peer_by_name(device(), String.t()) :: {:ok, %{} | nil} | {:error, any()} + @spec peer_by_name(device(), String.t()) :: async_reply() def peer_by_name(_dev, _name), do: err() @doc """ Retrieve this node's info """ - @spec self_node(device()) :: {:ok, %{}} | {:error, any()} + @spec self_node(device()) :: async_reply() def self_node(_dev), do: err() @doc """ Retrieve a peer by its tailnet IP. """ - @spec peer_by_tailnet_ip(device(), Tailscale.ip_addr()) :: {:ok, %{} | nil} | {:error, any()} + @spec peer_by_tailnet_ip(device(), Tailscale.ip_addr()) :: async_reply() def peer_by_tailnet_ip(_dev, _ip), do: err() @doc """ Retrieve the most narrow set of peers that accept packets for the specified IP. """ - @spec peers_with_route(device(), Tailscale.ip_addr()) :: {:ok, [%{}]} | {:error, any()} + @spec peers_with_route(device(), Tailscale.ip_addr()) :: async_reply() def peers_with_route(_dev, _ip), do: err() @doc """ Load key state from the specified path, generating a new state if the file doesn't exist. """ - @spec load_key_file(String.t()) :: {:ok, Tailscale.Keystate.t()} | {:error, any()} + @spec load_key_file(String.t()) :: async_reply() def load_key_file(_path), do: err() + + @doc """ + Raise a `:badarg` exception. + """ + @spec raise_badarg() :: nil + def raise_badarg(), do: err() + + if @testing_nifs do + @doc """ + DEV ONLY: trigger an async panic in the Rust code with the given message (if provided). + """ + @spec async_panic(String.t() | nil) :: async_reply() + def async_panic(_msg \\ nil), do: err() + + @doc """ + DEV ONLY: trigger a raised exception in the Rust code with the given message. + """ + @spec async_raise(String.t(), boolean()) :: async_reply() + def async_raise(_msg, _atom \\ false), do: err() + + @doc """ + DEV ONLY: trigger an asynchronous error in the Rust code with the given message. + """ + @spec async_error(String.t(), boolean()) :: async_reply() + def async_error(_msg, _atom \\ false), do: err() + + @doc """ + DEV ONLY: trigger an asynchronous `:badarg` in the Rust code with the given message. + """ + @spec async_badarg() :: async_reply() + def async_badarg(), do: err() + end end diff --git a/ts_elixir/lib/tailscale/tcp.ex b/ts_elixir/lib/tailscale/tcp.ex index a51a378d..3d63bc37 100644 --- a/ts_elixir/lib/tailscale/tcp.ex +++ b/ts_elixir/lib/tailscale/tcp.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Tcp do + require Tailscale.Util + @moduledoc """ Functionality to create tailscale TCP sockets. @@ -19,7 +21,7 @@ defmodule Tailscale.Tcp do @spec listen(Tailscale.t(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: {:ok, Tailscale.Tcp.Listener.t()} | {:error, any()} def listen(dev, addr, port) do - Tailscale.Native.tcp_listen(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.tcp_listen(dev, addr, port)) end @doc """ @@ -28,6 +30,6 @@ defmodule Tailscale.Tcp do @spec connect(Tailscale.t(), Tailscale.ip_addr(), :inet.port_number()) :: {:ok, Tailscale.Tcp.Stream.t()} | {:error, any()} def connect(dev, addr, port) do - Tailscale.Native.tcp_connect(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.tcp_connect(dev, addr, port)) end end diff --git a/ts_elixir/lib/tailscale/tcp/listener.ex b/ts_elixir/lib/tailscale/tcp/listener.ex index c26da408..d15889a9 100644 --- a/ts_elixir/lib/tailscale/tcp/listener.ex +++ b/ts_elixir/lib/tailscale/tcp/listener.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Tcp.Listener do + require Tailscale.Util + @moduledoc """ Tailscale TCP listening socket functionality. """ @@ -15,7 +17,7 @@ defmodule Tailscale.Tcp.Listener do Blocks until a connection is ready. """ def accept(res) do - Tailscale.Native.tcp_accept(res) + Tailscale.Util.await(Tailscale.Native.tcp_accept(res)) end @doc """ diff --git a/ts_elixir/lib/tailscale/tcp/stream.ex b/ts_elixir/lib/tailscale/tcp/stream.ex index 189ee9f3..58184574 100644 --- a/ts_elixir/lib/tailscale/tcp/stream.ex +++ b/ts_elixir/lib/tailscale/tcp/stream.ex @@ -3,6 +3,8 @@ defmodule Tailscale.Tcp.Stream do Tailscale TCP sockets (connected). """ + require Tailscale.Util + @typedoc """ A handle to a TCP stream (connected socket). """ @@ -15,7 +17,7 @@ defmodule Tailscale.Tcp.Stream do Returns the number of bytes actually sent. """ def send(res, msg) do - Tailscale.Native.tcp_send(res, msg) + Tailscale.Util.await(Tailscale.Native.tcp_send(res, msg)) end @spec send_all(t(), binary()) :: :ok | {:error, any()} @@ -27,7 +29,7 @@ defmodule Tailscale.Tcp.Stream do case Tailscale.Tcp.Stream.send(res, msg) do {:ok, ^len} -> :ok - {:ok, n} -> Tailscale.Tcp.Stream.send_all(res, binary_slice(msg, n..len)) + {:ok, n} -> send_all(res, binary_slice(msg, n..len)) err -> err end end @@ -37,7 +39,7 @@ defmodule Tailscale.Tcp.Stream do Receive data from the TCP socket, blocking until at least one byte can be received. """ def recv(res) do - Tailscale.Native.tcp_recv(res) + Tailscale.Util.await(Tailscale.Native.tcp_recv(res)) end @spec local_addr(t()) :: {:inet.ip_address(), :inet.port_number()} diff --git a/ts_elixir/lib/tailscale/udp.ex b/ts_elixir/lib/tailscale/udp.ex index 4d6a5993..d97720c3 100644 --- a/ts_elixir/lib/tailscale/udp.ex +++ b/ts_elixir/lib/tailscale/udp.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Udp do + require Tailscale.Util + @moduledoc """ Tailscale UDP sockets. """ @@ -21,7 +23,7 @@ defmodule Tailscale.Udp do - `port`: the port number to bind. """ def bind(dev, addr, port) do - Tailscale.Native.udp_bind(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.udp_bind(dev, addr, port)) end @spec send(t(), Tailscale.ip_addr(), :inet.port_number(), binary()) :: :ok | {:error, any()} @@ -37,7 +39,7 @@ defmodule Tailscale.Udp do - `payload`: the message payload. """ def send(sock, remote, port, payload) do - Tailscale.Native.udp_send(sock, remote, port, payload) + Tailscale.Util.await(Tailscale.Native.udp_send(sock, remote, port, payload)) end @spec recv(t()) :: {:ok, Tailscale.ip_addr(), :inet.port_number(), binary()} | {:error, any()} @@ -45,7 +47,7 @@ defmodule Tailscale.Udp do Receive a packet from the socket, blocking until one is ready. """ def recv(sock) do - Tailscale.Native.udp_recv(sock) + Tailscale.Util.await(Tailscale.Native.udp_recv(sock)) end @doc """ diff --git a/ts_elixir/lib/tailscale/util.ex b/ts_elixir/lib/tailscale/util.ex new file mode 100644 index 00000000..33e4ed37 --- /dev/null +++ b/ts_elixir/lib/tailscale/util.ex @@ -0,0 +1,59 @@ +defmodule Tailscale.Util do + @moduledoc false + # Internal utilities. + + @doc """ + Helper to await a Rust-side-async function that responds via message passing. + + Assumes the callee `block` returns the `:async` branch of `t:Tailscale.Native.async_reply/0`. Any + other response is returned verbatim, assumed to be an error. + """ + defmacro await(block, timeout \\ :infinity) do + quote do + Task.async(fn -> + Tailscale.Util.await_local(unquote(block), :infinity) + end) + |> Task.await(unquote(timeout)) + end + end + + @doc """ + Helper to await a Rust-side-async function that responds via message passing. + + Assumes the callee `block` returns the `:async` branch of `t:Tailscale.Native.async_reply/0`. Any + other response is returned verbatim, assumed to be an error. + + This macro (unlike `Tailscale.Util.await/2`) awaits a response message in the current process + without spawning a `m:Task`. This may be desirable to avoid the slight overhead of spawning a new + process, but may not be preferred if this process's mailbox is likely to be busy. + """ + defmacro await_local(block, timeout \\ :infinity) do + quote do + case unquote(block) do + {:async, ref} -> + receive do + {{:tailscale, ^ref}, result} -> result + after + unquote(timeout) -> + {:error, :timeout} + end + + other -> + other + end + |> Tailscale.Util.normalize_result() + end + end + + @doc """ + Normalize an async result to a standard Elixir-shaped return. + """ + def normalize_result({:ok, _} = result), do: normalize_tuple(result) + def normalize_result({:nif_panic, _} = result), do: {:error, normalize_tuple(result)} + def normalize_result({:raise, :badarg}), do: Tailscale.Native.raise_badarg() + def normalize_result({:raise, t}), do: raise(t) + def normalize_result(otherwise), do: otherwise + + defp normalize_tuple({a, {}}), do: a + defp normalize_tuple(a), do: a +end diff --git a/ts_elixir/native/ts_elixir/Cargo.toml b/ts_elixir/native/ts_elixir/Cargo.toml index e5c660d6..67bc94d3 100644 --- a/ts_elixir/native/ts_elixir/Cargo.toml +++ b/ts_elixir/native/ts_elixir/Cargo.toml @@ -11,12 +11,19 @@ license.workspace = true rust-version.workspace = true [dependencies] -rustler = "0.37.2" - tailscale = { workspace = true } +futures-util.workspace = true +rustler = "0.37.2" +tap = "1.0" tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[features] +# Additional testing functions that can directly trigger panics and errors to exercise the code in +# development. +testing-nifs = [] [lib] crate-type = ["cdylib"] diff --git a/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py b/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py index e5ed63e6..72492424 100644 --- a/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py +++ b/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py @@ -60,7 +60,7 @@ def main(): for dep in list(cargotoml[name].keys()): value = cargotoml[name][dep] - if type(value) == dict and value['workspace'] is True: + if type(value) == dict and value.get('workspace') is True: if args.repo_sha and (dep.startswith('tailscale') or dep.startswith('ts_')): value['git'] = f'https://github.com/tailscale/tailscale-rs' value['rev'] = args.repo_sha diff --git a/ts_elixir/native/ts_elixir/src/async_reply.rs b/ts_elixir/native/ts_elixir/src/async_reply.rs new file mode 100644 index 00000000..15f276d7 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/async_reply.rs @@ -0,0 +1,160 @@ +//! Facilities for sending asynchronous responses from NIFs. +//! +//! The motivation is that the Erlang DirtyIO scheduler is a thread-pool with inherently +//! limited concurrency (n = threads in the pool), and our NIFs will have to block a whole +//! thread on that pool while they're doing anything, even if it's asynchronous work running +//! on tokio. +//! +//! To avoid that, we adopt a more Erlang/Elixir-oriented approach and respond via message +//! passing. This doesn't block the BEAM at all, specifically because we're essentially +//! interacting with its event loop directly, rather than through the opaque abstraction of +//! a function running on a foreign thread. +//! +//! Our NIFs with async work to do return immediately with `{:async, REF}`, where `REF` is a +//! BEAM reference that uniquely identifies the function invocation. In the background, we +//! do whatever we need to and reply eventually (to the original caller's pid) with a +//! message holding the result of the function call and the original `REF` for correlation. + +use std::panic::AssertUnwindSafe; + +use futures_util::FutureExt; +use rustler::{Encoder, NifResult, OwnedEnv, Term}; + +use crate::{TOKIO_RUNTIME, atoms}; + +pub type AsyncReply<'a> = (rustler::Atom, rustler::Reference<'a>); + +/// Spiritual reimplementation of [`rustler::thread::spawn`] for futures. +/// +/// `fut` is executed in a tokio task, and the result is passed to `post`, which encodes a +/// result to pass back to `pid` as a message. If `fut` or `post` panics, `on_panic` is +/// invoked instead with the encoded reason for the panic, and the returned term is passed +/// back to the calling `pid`. +/// +/// Returns a [`rustler::Reference`] which uniquely identifies this particular spawn call. +/// The same reference is also provided to `post` and `on_panic`; they may or may not choose +/// to make use of it for correlation. +/// +/// NB: this function intentionally does not encode any specifics about the response format. +/// Conventionally, our NIFs respond with `{{:tailscale, REF}, PAYLOAD}`, but this function +/// is general-purpose and doesn't make that assumption: it responds with whatever you +/// tell it to, in the interest of separating concerns. The pieces specific to our current +/// async reply convention are encoded in [`reply_async`] and [`try_reply_async`]. +pub fn spawn( + env: rustler::Env, + fut: F, + post: Post, + on_panic: OnPanic, +) -> rustler::Reference +where + F: Future + Send + 'static, + F::Output: std::panic::UnwindSafe, + Post: for<'env> FnOnce(rustler::Env<'env>, rustler::Reference<'env>, F::Output) -> Term<'env> + + Send + + std::panic::UnwindSafe + + 'static, + OnPanic: for<'env> FnOnce(rustler::Env<'env>, rustler::Reference<'env>, Term<'env>) -> Term<'env> + + Send + + 'static, +{ + let pid = env.pid(); + let ref_ = env.make_ref(); + + let mut env = OwnedEnv::new(); + let saved_ref = env.save(ref_); + + TOKIO_RUNTIME.spawn(async move { + let result = AssertUnwindSafe(fut).catch_unwind().await.and_then(|result| { + std::panic::catch_unwind(|| { + if env.run(|env| { + let ref_ = saved_ref.load(env).decode::().unwrap(); + let value = post(env, ref_, result).encode(env); + + env.send(&pid, value) + }).is_err() { + tracing::error!(target_pid = ?pid.as_c_arg(), "failed sending success reply from spawn, process dead?"); + } + }) + }); + + if let Err(err) = result { + let send_result = env.send_and_clear(&pid, move |env| { + let ref_ = saved_ref.load(env).decode::().unwrap(); + + let reason = if let Some(string) = err.downcast_ref::() { + string.encode(env) + } else if let Some(&s) = err.downcast_ref::<&'static str>() { + s.encode(env) + } else { + ().encode(env) + }; + + on_panic(env, ref_, reason) + }); + + if send_result.is_err() { + tracing::error!(target_pid = ?pid.as_c_arg(), "failed sending panic reply from spawn, process dead?"); + } + } + }); + + ref_ +} + +/// Convenience wrapper for [`spawn`] when the return type is [`crate::Result`], +/// automatically converting the response to a reply +/// `{:ok, TERM} | {:error, TERM} | {:nif_panic, TERM} | {:raise | TERM}` wrapped in +/// `{:tailscale, ref, REPLY}`. +pub fn try_reply_async(env: rustler::Env, fut: F) -> AsyncReply +where + F: Future> + Send + 'static, + T: Encoder, +{ + let ref_ = spawn( + env, + async move { AssertUnwindSafe(fut.await) }, + move |env, ref_, t| { + let resp = match t.0 { + Ok(val) => (atoms::ok(), val).encode(env), + Err(e) => encode_async_err(env, e), + }; + + async_resp(ref_, resp).encode(env) + }, + move |env, ref_, reason| async_resp(ref_, (atoms::nif_panic(), reason)).encode(env), + ); + + (atoms::async_(), ref_) +} + +#[rustler::nif] +fn raise_badarg() -> NifResult<()> { + Err(rustler::Error::BadArg) +} + +pub fn async_resp<'r, T>(ref_: rustler::Reference<'r>, value: T) -> (AsyncReply<'r>, T) { + ((atoms::tailscale(), ref_), value) +} + +/// Encode the given [`rustler::Error`] as a [`Term`]. +/// +/// This is needed because [`rustler::Error`] typically expects to be returned from a NIF, +/// where it can directly raise exceptions on the [`Env`]. We don't want to do that here, we +/// want to forward the exception to raise through message passing. On the Elixir side, +/// `Tailscale.Util.normalize_result` handles converting the value into the correct form +/// (`{:error, TERM}` or a raised exception). +fn encode_async_err(env: rustler::Env, err: rustler::Error) -> Term { + match err { + rustler::Error::Term(b) => (atoms::error(), b.encode(env)).encode(env), + rustler::Error::Atom(a) => match rustler::Atom::from_str(env, a) { + Ok(atom) => env.error_tuple(atom), + Err(_e) => (atoms::raise(), atoms::badarg()).encode(env), + }, + rustler::Error::BadArg => (atoms::raise(), atoms::badarg()).encode(env), + rustler::Error::RaiseAtom(atom) => match rustler::Atom::from_str(env, atom) { + Ok(atom) => (atoms::raise(), atom).encode(env), + Err(_e) => (atoms::raise(), atoms::badarg()).encode(env), + }, + rustler::Error::RaiseTerm(t) => (atoms::raise(), t.encode(env)).encode(env), + } +} diff --git a/ts_elixir/native/ts_elixir/src/config.rs b/ts_elixir/native/ts_elixir/src/config.rs index f03190b4..52ae9ba7 100644 --- a/ts_elixir/native/ts_elixir/src/config.rs +++ b/ts_elixir/native/ts_elixir/src/config.rs @@ -30,14 +30,14 @@ pub fn config_from_erl( config.key_state = value .decode::()? .try_into() - .map_err(|_| rustler::Error::Atom("badkeys"))?; + .map_err(|_| rustler::Error::BadArg)?; } if let Some(value) = erl_config.get(&atoms::control_url()) { config.control_server_url = value.decode::<&str>()?.parse().map_err(|e| { tracing::error!(error = %e, "parsing control server url"); - rustler::Error::Atom("bad_url") + rustler::Error::BadArg })?; } diff --git a/ts_elixir/native/ts_elixir/src/erl_ip.rs b/ts_elixir/native/ts_elixir/src/erl_ip.rs new file mode 100644 index 00000000..97d6027b --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/erl_ip.rs @@ -0,0 +1,93 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr, +}; + +use rustler::{Encoder, NifResult, Term}; + +/// Erlang-formatted IP. +/// +/// Supports decoding from either a string or `:inet` (tuple of octets or segments) format, +/// always encodes into the `:inet` format. +#[derive(Copy, Clone, Debug)] +pub struct ErlIp(pub IpAddr); + +impl From for ErlIp { + fn from(value: Ipv4Addr) -> Self { + Self(value.into()) + } +} + +impl From for ErlIp { + fn from(value: Ipv6Addr) -> Self { + Self(value.into()) + } +} + +impl From for ErlIp { + fn from(value: IpAddr) -> Self { + Self(value) + } +} + +impl From for IpAddr { + fn from(value: ErlIp) -> Self { + value.0 + } +} + +impl<'a> rustler::Decoder<'a> for ErlIp { + fn decode(ip: Term<'a>) -> NifResult { + if let Ok(tuple) = rustler::types::tuple::get_tuple(ip) { + if tuple.len() == 4 { + let mut octets = [0u8; 4]; + + for (i, elem) in tuple.into_iter().take(4).enumerate() { + octets[i] = elem.decode()?; + } + + return Ok(Self(Ipv4Addr::from_octets(octets).into())); + } + + if tuple.len() == 8 { + let mut segments = [0u16; 8]; + + for (i, elem) in tuple.into_iter().take(8).enumerate() { + segments[i] = elem.decode()?; + } + + return Ok(Self(Ipv6Addr::from_segments(segments).into())); + } + } + + if let Ok(s) = ip.decode::<&str>() { + let ip = IpAddr::from_str(s).map_err(|e| { + tracing::error!(error = %e, "parsing ip addr"); + + rustler::Error::BadArg + })?; + + return Ok(Self(ip)); + } + + Err(rustler::Error::BadArg) + } +} + +impl Encoder for ErlIp { + fn encode<'a>(&self, env: rustler::Env<'a>) -> Term<'a> { + match self.0 { + IpAddr::V4(ip) => { + let octets = ip.octets(); + (octets[0], octets[1], octets[2], octets[3]).encode(env) + } + IpAddr::V6(ip) => { + // rustler doesn't provide `impl Encoder` for 8-length tuples + let segments = ip.segments().map(|segment| segment.encode(env)); + + let tuple = rustler::types::tuple::make_tuple(env, &segments); + tuple.encode(env) + } + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/helpers.rs b/ts_elixir/native/ts_elixir/src/helpers.rs new file mode 100644 index 00000000..80d4d393 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/helpers.rs @@ -0,0 +1,22 @@ +use std::{fmt::Display, net::SocketAddr}; + +use rustler::{NifResult, ResourceArc}; + +use crate::erl_ip::ErlIp; + +/// Wrap the given [`rustler::Resource`] in a [`ResourceArc`] inside a [`NifResult`]. +pub fn ok_arc(t: T) -> NifResult> +where + T: rustler::Resource, +{ + Ok(ResourceArc::new(t)) +} + +/// Convert the argument into a [`rustler::Error`] by making it into a string. +pub fn term_err(e: impl Display) -> rustler::Error { + rustler::Error::Term(Box::new(e.to_string())) +} + +pub fn sockaddr_to_erl(addr: SocketAddr) -> (ErlIp, u16) { + (ErlIp(addr.ip()), addr.port()) +} diff --git a/ts_elixir/native/ts_elixir/src/ip_or_self.rs b/ts_elixir/native/ts_elixir/src/ip_or_self.rs new file mode 100644 index 00000000..f4482d5d --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/ip_or_self.rs @@ -0,0 +1,51 @@ +use std::net::IpAddr; + +use rustler::{Error, NifResult, Term}; + +use crate::{atoms, erl_ip::ErlIp}; + +/// A literal IP address, the atom `:ip4`, or the atom `:ip6`. +/// +/// The latter two mean this node's IPv4 or IPv6 address, respectively. +pub enum IpOrSelf { + Ip(ErlIp), + SelfV4, + SelfV6, +} + +impl<'a> rustler::Decoder<'a> for IpOrSelf { + fn decode(ip: Term<'a>) -> NifResult { + if let Ok(ip) = ip.decode::() { + return Ok(Self::Ip(ip)); + } + + let atom = ip.decode::()?; + if atom == atoms::ip4() { + return Ok(Self::SelfV4); + } + + if atom == atoms::ip6() { + return Ok(Self::SelfV6); + } + + Err(Error::BadArg) + } +} + +impl IpOrSelf { + pub async fn resolve(&self, dev: &tailscale::Device) -> NifResult { + match self { + IpOrSelf::Ip(ip) => Ok(ip.0), + IpOrSelf::SelfV4 => dev + .ipv4_addr() + .await + .map(Into::into) + .map_err(|e| Error::Term(Box::new(e.to_string()))), + IpOrSelf::SelfV6 => dev + .ipv6_addr() + .await + .map(Into::into) + .map_err(|e| Error::Term(Box::new(e.to_string()))), + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/lib.rs b/ts_elixir/native/ts_elixir/src/lib.rs index e8474bfe..87f6cf43 100644 --- a/ts_elixir/native/ts_elixir/src/lib.rs +++ b/ts_elixir/native/ts_elixir/src/lib.rs @@ -2,29 +2,46 @@ use std::{ collections::HashMap, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - str::FromStr, - sync::{Arc, LazyLock}, + sync::{Arc, LazyLock, Once}, }; -use rustler::{Encoder, NifResult, ResourceArc, Term}; +use rustler::{NifResult, ResourceArc, Term}; +use tap::Pipe; +use tracing::level_filters::LevelFilter; +mod async_reply; mod config; +mod erl_ip; +mod helpers; +mod ip_or_self; +mod node_info; mod tcp; +#[cfg(feature = "testing-nifs")] +mod testing_nifs; mod udp; +use async_reply::{AsyncReply, try_reply_async}; +use config::Keystate; +use erl_ip::ErlIp; +use helpers::{ok_arc, sockaddr_to_erl, term_err}; +use ip_or_self::IpOrSelf; +use node_info::NodeInfo; use tcp::{TcpListener, TcpStream}; use udp::UdpSocket; -use crate::config::Keystate; - mod atoms { rustler::atoms! { ok, + async_ = "async", error, + nif_panic, + badarg, + raise, ip4, ip6, + + tailscale, } } @@ -32,49 +49,6 @@ struct Device { inner: Arc, } -#[derive(rustler::NifStruct)] -#[module = "Tailscale.NodeInfo"] -struct NodeInfo<'a> { - id: i64, - stable_id: String, - hostname: String, - tailnet: Option, - tags: Vec, - tailnet_addresses: Vec>, - derp_region: Option, - node_key: String, - disco_key: Option, - machine_key: Option, - underlay_addresses: Vec>, -} - -impl<'a> NodeInfo<'a> { - fn from_node(env: rustler::Env<'a>, value: tailscale::NodeInfo) -> Self { - Self { - id: value.id, - stable_id: value.stable_id.0, - hostname: value.hostname, - tailnet: value.tailnet, - tags: value.tags, - tailnet_addresses: vec![ - ip_to_erl(env, value.tailnet_address.ipv4.addr()), - ip_to_erl(env, value.tailnet_address.ipv6.addr()), - ], - derp_region: value.derp_region.map(|x| x.0.get()), - node_key: value.node_key.to_string(), - disco_key: value.disco_key.as_ref().map(ToString::to_string), - machine_key: value.machine_key.as_ref().map(ToString::to_string), - underlay_addresses: value - .underlay_addresses - .into_iter() - .map(|x| (ip_to_erl(env, x.ip()), x.port()).encode(env)) - .collect(), - } - } -} - -type Result = core::result::Result>; - #[rustler::resource_impl] impl rustler::Resource for Device {} @@ -89,210 +63,123 @@ static TOKIO_RUNTIME: LazyLock = LazyLock::new(|| { rt }); -fn erl_result(env: rustler::Env, r: Result) -> Term { - match r { - Ok(t) => (atoms::ok(), t).encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } -} - -fn ok_arc(t: T) -> Result> -where - T: rustler::Resource, -{ - Ok(ResourceArc::new(t)) +#[rustler::nif] +fn start_tracing() { + static TRACING_ONCE: Once = Once::new(); + + TRACING_ONCE.call_once(|| { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + }); } -#[rustler::nif(schedule = "DirtyIo")] +#[rustler::nif] fn connect<'env>( env: rustler::Env<'env>, opts: HashMap>, -) -> NifResult<(rustler::Atom, Term<'env>)> { +) -> NifResult> { let (config, auth_key) = config::config_from_erl(&opts)?; - let dev = TOKIO_RUNTIME.block_on(async move { - let dev = tailscale::Device::new(&config, auth_key).await?; + try_reply_async(env, async move { + let dev = tailscale::Device::new(&config, auth_key) + .await + .map_err(term_err)?; ok_arc(Device { inner: Arc::new(dev), }) - }); - - match dev { - Ok(dev) => Ok((atoms::ok(), dev.encode(env))), - Err(e) => Err(rustler::Error::Term(Box::new(e.to_string()))), - } + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn load_key_file(env: rustler::Env, path: &str) -> impl Encoder { - let result = TOKIO_RUNTIME - .block_on(tailscale::config::load_key_file(path, Default::default())) - .map(|keys| { - let keys: tailscale::keys::NodeState = keys.into(); - let result: Keystate = keys.into(); - result - }) - .map_err(Into::into); - - erl_result(env, result) +#[rustler::nif] +fn load_key_file(env: rustler::Env, path: String) -> AsyncReply { + try_reply_async(env, async move { + tailscale::config::load_key_file(path, Default::default()) + .await + .map(tailscale::keys::NodeState::from) + .map(Keystate::from) + .map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> AsyncReply { let dev = dev.inner.clone(); - let addr = TOKIO_RUNTIME.block_on(dev.ipv4_addr()); - erl_result(env, addr.map(|ip| ip_to_erl(env, ip)).map_err(Into::into)) + try_reply_async(env, async move { + dev.ipv4_addr().await.map(ErlIp::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn ipv6_addr(env: rustler::Env<'_>, dev: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn ipv6_addr(env: rustler::Env<'_>, dev: ResourceArc) -> AsyncReply<'_> { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(dev.ipv6_addr()) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(ip) => (atoms::ok(), ip_to_erl(env, ip)).encode(env), - } + try_reply_async(env, async move { + dev.ipv6_addr().await.map(ErlIp::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn peer_by_name(env: rustler::Env<'_>, dev: ResourceArc, name: &str) -> impl Encoder { +#[rustler::nif] +fn peer_by_name<'e>(env: rustler::Env<'e>, dev: ResourceArc, name: &str) -> AsyncReply<'e> { let dev = dev.inner.clone(); let name = name.to_owned(); - match TOKIO_RUNTIME.block_on(async move { dev.peer_by_name(&name).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), - } + try_reply_async(env, async move { + dev.peer_by_name(&name) + .await + .map(|opt| opt.map(NodeInfo::from)) + .map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn self_node(env: rustler::Env<'_>, dev: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn self_node(env: rustler::Env<'_>, dev: ResourceArc) -> AsyncReply<'_> { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { dev.self_node().await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(peer) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), - } + try_reply_async(env, async move { + dev.self_node().await.map(NodeInfo::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn peer_by_tailnet_ip(env: rustler::Env<'_>, dev: ResourceArc, ip: Term) -> impl Encoder { +#[rustler::nif] +fn peer_by_tailnet_ip<'e>( + env: rustler::Env<'e>, + dev: ResourceArc, + ip: ErlIp, +) -> NifResult> { let dev = dev.inner.clone(); - let Some(ip) = ip_from_erl(ip) else { - return env.error_tuple("invalid ip"); - }; - - match TOKIO_RUNTIME.block_on(async move { dev.peer_by_tailnet_ip(ip).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), - } -} -#[rustler::nif(schedule = "DirtyIo")] -fn peers_with_route(env: rustler::Env<'_>, dev: ResourceArc, ip: Term) -> impl Encoder { + try_reply_async(env, async move { + dev.peer_by_tailnet_ip(ip.into()) + .await + .map(|x| x.map(NodeInfo::from)) + .map_err(term_err) + }) + .pipe(Ok) +} + +#[rustler::nif] +fn peers_with_route<'e>( + env: rustler::Env<'e>, + dev: ResourceArc, + ip: ErlIp, +) -> NifResult> { let dev = dev.inner.clone(); - let Some(ip) = ip_from_erl(ip) else { - return env.error_tuple("invalid ip"); - }; - - match TOKIO_RUNTIME.block_on(async move { dev.peers_with_route(ip).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(peers) => ( - atoms::ok(), - peers - .into_iter() - .map(|x| NodeInfo::from_node(env, x)) - .collect::>(), - ) - .encode(env), - } -} - -fn ip_to_erl(env: rustler::Env, ip: impl Into) -> Term { - match ip.into() { - IpAddr::V4(ip) => { - let octets = ip.octets(); - (octets[0], octets[1], octets[2], octets[3]).encode(env) - } - IpAddr::V6(ip) => { - // rustler doesn't provide `impl Encoder` for 8-length tuples - let segments = ip.segments().map(|segment| segment.encode(env)); - - let tuple = rustler::types::tuple::make_tuple(env, &segments); - tuple.encode(env) - } - } -} - -enum IpOrSelf { - Ip(IpAddr), - SelfV4, - SelfV6, -} - -impl IpOrSelf { - pub fn new(ip: Term<'_>) -> Option { - if let Some(ip) = ip_from_erl(ip) { - return Some(Self::Ip(ip)); - } - - let atom = ip.decode::().ok()?; - if atom == atoms::ip4() { - return Some(Self::SelfV4); - } - - if atom == atoms::ip6() { - return Some(Self::SelfV6); - } - - None - } - - pub async fn resolve(&self, dev: &tailscale::Device) -> Result { - match self { - IpOrSelf::Ip(ip) => Ok(*ip), - IpOrSelf::SelfV4 => dev.ipv4_addr().await.map(Into::into).map_err(Into::into), - IpOrSelf::SelfV6 => dev.ipv6_addr().await.map(Into::into).map_err(Into::into), - } - } -} - -fn ip_from_erl(ip: Term) -> Option { - if let Ok(tuple) = rustler::types::tuple::get_tuple(ip) { - if tuple.len() == 4 { - let mut octets = [0u8; 4]; - - for (i, elem) in tuple.into_iter().take(4).enumerate() { - octets[i] = elem.decode().ok()?; - } - - return Some(Ipv4Addr::from_octets(octets).into()); - } - - if tuple.len() == 8 { - let mut segments = [0u16; 8]; - - for (i, elem) in tuple.into_iter().take(8).enumerate() { - segments[i] = elem.decode().ok()?; - } - - return Some(Ipv6Addr::from_segments(segments).into()); - } - } - - if let Ok(s) = ip.decode::<&str>() { - return IpAddr::from_str(s).ok(); - } - - None -} -fn sockaddr_to_erl(env: rustler::Env, addr: SocketAddr) -> impl Encoder { - (ip_to_erl(env, addr.ip()), addr.port()) + try_reply_async(env, async move { + dev.peers_with_route(ip.into()) + .await + .map(|peers| peers.into_iter().map(NodeInfo::from).collect::>()) + .map_err(term_err) + }) + .pipe(Ok) } fn load(env: rustler::Env, _term: Term) -> bool { diff --git a/ts_elixir/native/ts_elixir/src/node_info.rs b/ts_elixir/native/ts_elixir/src/node_info.rs new file mode 100644 index 00000000..638a68d8 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/node_info.rs @@ -0,0 +1,43 @@ +use crate::{erl_ip::ErlIp, helpers::sockaddr_to_erl}; + +/// Info about a Tailscale peer. +#[derive(rustler::NifStruct)] +#[module = "Tailscale.NodeInfo"] +pub struct NodeInfo { + id: i64, + stable_id: String, + hostname: String, + tailnet: Option, + tags: Vec, + tailnet_addresses: Vec, + derp_region: Option, + node_key: String, + disco_key: Option, + machine_key: Option, + underlay_addresses: Vec<(ErlIp, u16)>, +} + +impl From for NodeInfo { + fn from(value: tailscale::NodeInfo) -> Self { + Self { + id: value.id, + stable_id: value.stable_id.0, + hostname: value.hostname, + tailnet: value.tailnet, + tags: value.tags, + tailnet_addresses: vec![ + ErlIp::from(value.tailnet_address.ipv4.addr()), + ErlIp::from(value.tailnet_address.ipv6.addr()), + ], + derp_region: value.derp_region.map(|x| x.0.get()), + node_key: value.node_key.to_string(), + disco_key: value.disco_key.as_ref().map(ToString::to_string), + machine_key: value.machine_key.as_ref().map(ToString::to_string), + underlay_addresses: value + .underlay_addresses + .into_iter() + .map(sockaddr_to_erl) + .collect(), + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/tcp.rs b/ts_elixir/native/ts_elixir/src/tcp.rs index a539d34a..5db2f0a7 100644 --- a/ts_elixir/native/ts_elixir/src/tcp.rs +++ b/ts_elixir/native/ts_elixir/src/tcp.rs @@ -1,8 +1,12 @@ use std::sync::Arc; -use rustler::{Encoder, ResourceArc}; +use rustler::{Encoder, NifResult, ResourceArc}; +use tap::Pipe; -use crate::{IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_result, ip_from_erl, ok_arc}; +use crate::{ + AsyncReply, IpOrSelf, erl_ip::ErlIp, helpers::term_err, ok_arc, sockaddr_to_erl, + try_reply_async, +}; pub(crate) struct TcpListener { inner: Arc, @@ -18,98 +22,95 @@ impl rustler::Resource for TcpListener {} #[rustler::resource_impl] impl rustler::Resource for TcpStream {} -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_listen( - env: rustler::Env, +#[rustler::nif] +fn tcp_listen<'e>( + env: rustler::Env<'e>, dev: ResourceArc, - addr: rustler::Term, + ip: IpOrSelf, port: u16, -) -> impl Encoder { +) -> NifResult> { let dev = dev.inner.clone(); - let ip = IpOrSelf::new(addr); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = ip.ok_or("invalid ip addr")?.resolve(&dev).await?; - let sock = dev.tcp_listen((addr, port).into()).await?; + try_reply_async(env, async move { + let addr = ip.resolve(&dev).await?; + let sock = dev + .tcp_listen((addr, port).into()) + .await + .map_err(term_err)?; ok_arc(TcpListener { inner: Arc::new(sock), }) - }); - - erl_result(env, sock) + }) + .pipe(Ok) } #[rustler::nif] -fn tcp_listen_local_addr(env: rustler::Env, listener: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, listener.inner.local_addr()) +fn tcp_listen_local_addr(listener: ResourceArc) -> impl Encoder { + sockaddr_to_erl(listener.inner.local_addr()) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_connect( - env: rustler::Env<'_>, +#[rustler::nif] +fn tcp_connect<'e>( + env: rustler::Env<'e>, dev: ResourceArc, - addr: rustler::Term, + addr: ErlIp, port: u16, -) -> impl Encoder { - let addr = ip_from_erl(addr); +) -> NifResult> { let dev = dev.inner.clone(); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = addr.ok_or("invalid ip addr")?; - let sock = dev.tcp_connect((addr, port).into()).await?; + try_reply_async(env, async move { + let sock = dev + .tcp_connect((addr, port).into()) + .await + .map_err(term_err)?; ok_arc(TcpStream { inner: Arc::new(sock), }) - }); - - erl_result(env, sock) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> AsyncReply<'_> { let inner = sock.inner.clone(); - let sock = TOKIO_RUNTIME.block_on(async move { - let stream = inner.accept().await?; + try_reply_async(env, async move { + let stream = inner.accept().await.map_err(term_err)?; ok_arc(TcpStream { inner: Arc::new(stream), }) - }); - - erl_result(env, sock) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_send(env: rustler::Env, sock: ResourceArc, msg: Vec) -> rustler::Term { +#[rustler::nif] +fn tcp_send(env: rustler::Env, sock: ResourceArc, msg: Vec) -> AsyncReply { let inner = sock.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { inner.send(&msg).await }) { - Ok(n) => (atoms::ok(), n).encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } + try_reply_async(env, async move { inner.send(&msg).await.map_err(term_err) }) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> AsyncReply { let inner = sock.inner.clone(); - let buf = TOKIO_RUNTIME.block_on(async move { - let buf = inner.recv_bytes().await?; - Result::<_>::Ok(buf.to_vec()) - }); - - erl_result(env, buf) + try_reply_async(env, async move { + inner + .recv_bytes() + .await + .map(|b| b.to_vec()) + .map_err(term_err) + }) } #[rustler::nif] -fn tcp_local_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.local_addr()) +fn tcp_local_addr(sock: ResourceArc) -> impl Encoder { + sockaddr_to_erl(sock.inner.local_addr()) } #[rustler::nif] -fn tcp_remote_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.remote_addr()) +fn tcp_remote_addr(sock: ResourceArc) -> impl Encoder { + sockaddr_to_erl(sock.inner.remote_addr()) } diff --git a/ts_elixir/native/ts_elixir/src/testing_nifs.rs b/ts_elixir/native/ts_elixir/src/testing_nifs.rs new file mode 100644 index 00000000..d761cfbb --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/testing_nifs.rs @@ -0,0 +1,50 @@ +//! NIFs that intentionally return errors, panic, and raise exceptions. +//! +//! These are intended for testing the async message passing code and require the +//! `testing-nifs` feature flag to be enabled. + +use rustler::{Env, Error}; + +use crate::async_reply::{AsyncReply, try_reply_async}; + +#[rustler::nif] +pub fn async_panic(env: Env, msg: Option) -> AsyncReply { + try_reply_async(env, async move { + if let Some(msg) = msg { + panic!("{msg}"); + } else { + panic!() + } + + // Needed to indicate return type + #[allow(unreachable_code)] + Ok(()) + }) +} + +#[rustler::nif] +pub fn async_error<'e>(env: Env<'e>, s: String, atom: bool) -> AsyncReply<'e> { + try_reply_async(env, async move { + Result::<(), _>::Err(if atom { + Error::Atom(String::leak(s)) + } else { + Error::Term(Box::new(s)) + }) + }) +} + +#[rustler::nif] +pub fn async_raise<'e>(env: Env<'e>, s: String, atom: bool) -> AsyncReply<'e> { + try_reply_async(env, async move { + Result::<(), _>::Err(if atom { + Error::RaiseAtom(String::leak(s)) + } else { + Error::RaiseTerm(Box::new(s)) + }) + }) +} + +#[rustler::nif] +pub fn async_badarg<'e>(env: Env<'e>) -> AsyncReply<'e> { + try_reply_async(env, async move { Result::<(), _>::Err(Error::BadArg) }) +} diff --git a/ts_elixir/native/ts_elixir/src/udp.rs b/ts_elixir/native/ts_elixir/src/udp.rs index 21945328..f2bf5e3f 100644 --- a/ts_elixir/native/ts_elixir/src/udp.rs +++ b/ts_elixir/native/ts_elixir/src/udp.rs @@ -1,9 +1,11 @@ use std::sync::Arc; -use rustler::{Binary, Encoder, ResourceArc, Term}; +use rustler::{Binary, Encoder, NifResult, ResourceArc}; +use tap::Pipe; use crate::{ - Device, IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_result, ip_from_erl, ip_to_erl, ok_arc, + AsyncReply, Device, IpOrSelf, erl_ip::ErlIp, helpers::term_err, ok_arc, sockaddr_to_erl, + try_reply_async, }; pub struct UdpSocket { @@ -13,64 +15,59 @@ pub struct UdpSocket { #[rustler::resource_impl] impl rustler::Resource for UdpSocket {} -#[rustler::nif(schedule = "DirtyIo")] -fn udp_bind(env: rustler::Env, dev: ResourceArc, ip: Term, port: u16) -> impl Encoder { +#[rustler::nif] +fn udp_bind<'e>( + env: rustler::Env<'e>, + dev: ResourceArc, + ip: IpOrSelf, + port: u16, +) -> NifResult> { let dev = dev.inner.clone(); - let ip = IpOrSelf::new(ip); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = ip.ok_or("invalid ip addr")?.resolve(&dev).await?; - let sock = dev.udp_bind((addr, port).into()).await?; + try_reply_async(env, async move { + let addr = ip.resolve(&dev).await?; + let sock = dev.udp_bind((addr, port).into()).await.map_err(term_err)?; ok_arc(UdpSocket { inner: Arc::new(sock), }) - }); - - erl_result(env, sock) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn udp_send<'env>( - env: rustler::Env<'env>, +#[rustler::nif] +fn udp_send<'e>( + env: rustler::Env<'e>, sock: ResourceArc, - ip: Term, + addr: ErlIp, port: u16, msg: Binary, -) -> Term<'env> { - let addr = ip_from_erl(ip); +) -> NifResult> { let msg = msg.to_vec(); let sock = sock.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { - let addr = addr.ok_or("invalid ip addr")?; - - sock.send_to((addr, port).into(), &msg).await?; - - Result::<_>::Ok(()) - }) { - Ok(_) => atoms::ok().encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } + try_reply_async(env, async move { + sock.send_to((addr, port).into(), &msg) + .await + .map(|_| ()) + .map_err(term_err) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn udp_recv(env: rustler::Env, sock: ResourceArc) -> Term { - let (who, msg) = match sock.inner.recv_from_bytes_blocking() { - Ok((who, msg)) => (who, msg), - Err(e) => return erl_result(env, Result::<()>::Err(e.into())), - }; +#[rustler::nif] +fn udp_recv(env: rustler::Env, sock: ResourceArc) -> AsyncReply { + let sock = sock.inner.clone(); - ( - atoms::ok(), - ip_to_erl(env, who.ip()), - who.port(), - msg.to_vec(), - ) - .encode(env) + try_reply_async(env, async move { + sock.recv_from_bytes() + .await + .map(|(s, msg)| (ErlIp(s.ip()), s.port(), msg.to_vec())) + .map_err(term_err) + }) } #[rustler::nif] -fn udp_local_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.local_addr()) +fn udp_local_addr(sock: ResourceArc) -> impl Encoder { + sockaddr_to_erl(sock.inner.local_addr()) } diff --git a/ts_elixir/test/async_callback_test.exs b/ts_elixir/test/async_callback_test.exs new file mode 100644 index 00000000..d4b8d844 --- /dev/null +++ b/ts_elixir/test/async_callback_test.exs @@ -0,0 +1,69 @@ +defmodule Tailscale.Test.AsyncCallbacks do + use ExUnit.Case, async: true + require Tailscale.Util + alias Tailscale.Native + + defmacrop await(block, local) do + if local do + quote do + Tailscale.Util.await_local(unquote(block)) + end + else + quote do + Tailscale.Util.await(unquote(block)) + end + end + end + + for local <- [true, false] do + describe "async calls (local: #{local})" do + for msg <- ["msg", nil] do + test "panic (msg: #{msg})" do + result = await(Native.async_panic(unquote(msg)), local) + + {:error, {:nif_panic, arg}} = result + + if unquote(msg) != nil do + assert(arg == "msg") + end + end + end + + for atom <- [true, false] do + test "error (atom: #{atom})" do + assert( + await(Native.async_error("msg", unquote(atom)), local) == + {:error, + if unquote(atom) do + :msg + else + "msg" + end} + ) + end + + test("raise (atom: #{atom})") do + assert_raise RuntimeError, fn -> + msg = + if unquote(atom) do + "Elixir.RuntimeError" + else + "msg" + end + + await( + Native.async_raise(msg, unquote(atom)), + local + ) + end + end + end + + test "badarg" do + assert_raise ArgumentError, fn -> + await(Native.async_badarg(), local) + end + end + end + end +end