Skip to content

Add NIF for loading custom plugins #1519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 113 additions & 2 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <string>
#include <dlfcn.h>

#include "exla_client.h"
#include "exla_cuda.h"
Expand All @@ -11,11 +12,36 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"
#include "xla/service/custom_call_target_registry.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
// with calls to delete rather than just using the default destructor.

// Wraps the handle returned by `dlopen` for a loaded plugin library. The
// handle has to stay open for as long as EXLA is running, so it lives in a
// NIF resource whose custom free function calls `dlclose`. The caller is
// responsible for keeping the resource referenced so that it is not garbage
// collected.
struct ExlaPlugin {
  void* handle;
};

// Signature of a custom call exported by a plugin library: output buffers,
// input buffers, and one array of dimension sizes per entry of `dims`.
using ExlaCustomCallFunction = void (*)(void* out[], const void* in[], int** dims);

// Pairs a custom call's name with the function implementing it.
struct ExlaPluginCustomCall {
  const char* name;
  ExlaCustomCallFunction func;
};

// Resource type for `ExlaPlugin` resources. It is opened in `open_resources`
// with `free_exla_plugin` as its destructor.
static ErlNifResourceType* exla_plugin_resource_type;

// Destructor for `ExlaPlugin` resources: closes the library handle stored in
// the resource.
void free_exla_plugin(ErlNifEnv* env, void* obj) {
  auto* plugin = static_cast<ExlaPlugin*>(obj);
  if (plugin == nullptr) {
    return;
  }
  dlclose(plugin->handle);
}

void free_exla_executable(ErlNifEnv* env, void* obj) {
exla::ExlaExecutable** executable = reinterpret_cast<exla::ExlaExecutable**>(obj);
if (*executable != nullptr) {
Expand Down Expand Up @@ -65,10 +91,17 @@ static int open_resources(ErlNifEnv* env) {
if (!exla::nif::open_resource<exla::MLIRModule*>(env, mod, "ExlaMLIRModule")) {
return -1;
}

if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
return -1;
}

// Just a C Resource
ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER);
exla_plugin_resource_type = enif_open_resource_type(env, mod, "ExlaPlugin", free_exla_plugin, flags, NULL);
if (!exla_plugin_resource_type) {
return -1;
}

return 1;
}

Expand Down Expand Up @@ -911,6 +944,80 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env);
}

// Plugins

// Opens a shared library that provides custom call implementations.
//
// Arguments (argv):
//   0: path of the shared library, passed as-is to `dlopen`.
//
// Returns the result of `exla::nif::ok` wrapping an `ExlaPlugin` resource that
// owns the `dlopen` handle, or an error term when the argument count is wrong,
// the path cannot be decoded, the library cannot be opened, or the resource
// cannot be allocated.
ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 1) {
    return exla::nif::error(env, "Bad argument count.");
  }

  std::string library_path;

  if (!exla::nif::get(env, argv[0], library_path)) {
    return exla::nif::error(env, "Unable to get library path.");
  }

  void* handle = dlopen(library_path.c_str(), RTLD_NOW);
  if (!handle) {
    // Append the loader's own diagnostic (missing file, unresolved symbol, ...)
    // so the caller can tell why the library failed to load.
    std::string message = "Unable to open library.";
    const char* reason = dlerror();
    if (reason != nullptr) {
      message += " ";
      message += reason;
    }
    return exla::nif::error(env, message.c_str());
  }

  auto* plugin = static_cast<ExlaPlugin*>(
      enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)));
  if (plugin == nullptr) {
    // Defensive: without a resource nothing would ever call `dlclose`, so
    // release the handle here instead of leaking it.
    dlclose(handle);
    return exla::nif::error(env, "Unable to allocate plugin resource.");
  }
  plugin->handle = handle;

  // The term keeps its own reference to the resource, so the reference taken
  // by `enif_alloc_resource` is dropped right after the term is created.
  ERL_NIF_TERM result = enif_make_resource(env, plugin);
  enif_release_resource(plugin);

  return exla::nif::ok(env, result);
}

// Looks up `symbol` in a previously loaded plugin library and registers it as
// an XLA custom call target for the "Host" (CPU) platform.
//
// Arguments (argv):
//   0: `ExlaPlugin` resource returned by `load_custom_call_plugin_library`.
//   1: name of the exported symbol; also used as the custom call target name.
//   2: list of lists of integers with the operand dimensions.
//
// Returns `exla::nif::ok` on success, or an error term when an argument cannot
// be decoded or the symbol is not found in the library.
ERL_NIF_TERM register_custom_call_symbol(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 3) {
    return exla::nif::error(env, "Bad argument count.");
  }

  ExlaPlugin* plugin;
  std::string symbol;
  std::vector<std::vector<exla::int64>> dimensions;

  if (!enif_get_resource(env, argv[0], exla_plugin_resource_type, reinterpret_cast<void**>(&plugin))) {
    return exla::nif::error(env, "Unable to get plugin.");
  }
  if (!exla::nif::get(env, argv[1], symbol)) {
    return exla::nif::error(env, "Unable to get symbol.");
  }
  // NOTE: the dimensions are only validated for now. The symbol is registered
  // as-is, so nothing forwards them to the plugin function yet.
  if (!exla::nif::get_list(env, argv[2], dimensions)) {
    return exla::nif::error(env, "Unable to get dimensions.");
  }

  void* function = dlsym(plugin->handle, symbol.c_str());

  if (!function) {
    return exla::nif::error(env, "Could not find symbol.");
  }

  // Register through the registry directly. The
  // XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM macro declares a `static`
  // registrar object; inside a function body that object is initialized only
  // the first time the function runs, so every later symbol would silently be
  // left unregistered.
  // TODO: GPU/Client flag
  xla::CustomCallTargetRegistry::Global()->Register(symbol, function, "Host");

  return exla::nif::ok(env);
}

static ErlNifFunc exla_funcs[] = {
// MLIR Builder
{"mlir_new_context", 0, mlir_new_context},
Expand Down Expand Up @@ -947,6 +1054,10 @@ static ErlNifFunc exla_funcs[] = {
{"start_log_sink", 1, start_log_sink},
// Serialization
{"serialize_executable", 1, serialize_executable},
{"deserialize_executable", 2, deserialize_executable}};
{"deserialize_executable", 2, deserialize_executable},
// Plugins
{"load_custom_call_plugin_library", 1, load_custom_call_plugin_library},
{"register_custom_call_symbol", 3, register_custom_call_symbol}
};

// Declares the NIF library for the `Elixir.EXLA.NIF` module, exporting the
// functions listed in `exla_funcs`. `load` and `upgrade` are passed as
// callbacks; the other two callback slots are left NULL.
ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL);
19 changes: 19 additions & 0 deletions exla/c_src/exla/exla_nif_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,25 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::string>& var) {
return 1;
}

// Decodes an Erlang list of integer lists into `var`.
//
// Each element of `list` is decoded with the `std::vector<int64>` overload of
// `get_list` and appended to `var`. Returns 1 on success and 0 if `list` is
// not a proper list or any element fails to decode; on failure `var` keeps
// only the elements that were decoded successfully before the failing one.
int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::vector<int64>>& var) {
  unsigned int length;
  if (!enif_get_list_length(env, list, &length)) {
    return 0;
  }
  var.reserve(length);
  ERL_NIF_TERM head, tail;

  while (enif_get_list_cell(env, list, &head, &tail)) {
    // Decode straight into the slot in `var` rather than into a temporary
    // that would then be copied in.
    var.emplace_back();
    if (!get_list(env, head, var.back())) {
      // Drop the partially filled slot so a failure leaves no stray element.
      var.pop_back();
      return 0;
    }
    list = tail;
  }
  return 1;
}

// Fills `var` from `term` via `enif_inspect_binary` and returns that call's
// result unchanged.
int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) {
  const int inspected = enif_inspect_binary(env, term, var);
  return inspected;
}
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_nif_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::string>& var);

int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<xla::Shape>& var);

// Decodes a list of integer lists into `var`, appending one inner vector per
// list element. Returns 1 on success and 0 on failure.
int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::vector<int64>>& var);

template <typename T>
int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<T*>& var) {
unsigned int length;
Expand Down
1 change: 1 addition & 0 deletions exla/lib/exla/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ defmodule EXLA.Application do
name: EXLA.MLIR.ContextPool,
lazy: true},
EXLA.Client,
EXLA.Plugin,
EXLA.Defn.Lock,
EXLA.Defn.LockedCache,
{Task.Supervisor, name: EXLA.Defn.TaskSupervisor}
Expand Down
21 changes: 21 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,27 @@ defmodule EXLA.MLIR.Value do
{q, r}
end

# Emits a `stablehlo.custom_call` targeting the custom call registered under
# `registered_name`.
#
# Every value in `args` must belong to the same function as the first one.
# Each argument is followed in the operand list by a constant s64 tensor
# holding that argument's shape, so the operands are laid out as
# `[arg0, shape0, arg1, shape1, ...]`.
def plugin_custom_call(registered_name, [%Value{function: func} | _] = args, result_typespec) do
  operand_shapes =
    Enum.map(args, fn %Value{function: ^func} = value ->
      %{shape: op_shape} = get_typespec(value)
      # The shape is a tuple, so its rank comes from tuple_size/1; length/1
      # only accepts lists and would raise here.
      rank = tuple_size(op_shape)
      constant(func, Tuple.to_list(op_shape), Typespec.tensor({:s, 64}, {rank}))
    end)

  operands =
    args
    |> Enum.zip_with(operand_shapes, fn val, shape -> [val, shape] end)
    |> List.flatten()

  # TODO: GPU
  attributes = [
    call_target_name: attr_string(registered_name),
    backend_config: attr_string("Host")
  ]

  op(func, "stablehlo.custom_call", operands, result_typespec, attributes: attributes)
end

def get_tuple_element(%Value{function: func} = operand, index, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [index: attr_i32(index)]
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,8 @@ defmodule EXLA.NIF do
def get_c_api_client(_device_type), do: :erlang.nif_error(:undef)

def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef)

def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef)

def register_custom_call_symbol(_plugin, _symbol, _dimensions), do: :erlang.nif_error(:undef)
end
56 changes: 56 additions & 0 deletions exla/lib/exla/plugin.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule EXLA.Plugin do
@moduledoc """
Plugin system for registering custom calls.
"""
use GenServer

# TODO: Register and lookup per client

@doc """
Starts the plugin server, registered under the module name.
"""
def start_link(_opts) do
  server_opts = [name: __MODULE__]
  GenServer.start_link(__MODULE__, %{}, server_opts)
end

@doc """
Sends a cast asking the plugin server to load the library at `library_path`
and store it under `key`.
"""
def register(key, library_path),
  do: GenServer.cast(__MODULE__, {:register, key, library_path})

@doc """
Returns the plugin reference stored under `key`, or `nil` when there is none.
"""
def lookup(key), do: GenServer.call(__MODULE__, {:lookup, key})

@doc """
Registers `symbol` from the plugin stored under `key` as a custom call.

Returns `nil` without calling the NIF when no plugin is stored under `key`.
"""
def register_symbol(key, symbol, dimensions) do
  case lookup(key) do
    missing when missing in [nil, false] ->
      nil

    ref ->
      EXLA.NIF.register_custom_call_symbol(ref, symbol, dimensions)
  end
end

# The server state is a map from plugin key to loaded plugin reference; it
# starts out empty regardless of the init argument.
@impl true
def init(_opts), do: {:ok, %{}}

@impl true
def handle_cast({:register, key, library_path}, state) do
case state do
%{^key => _ref} ->
{:noreply, state}

%{} ->
ref =
library_path
|> EXLA.NIF.load_custom_call_plugin_library()
|> unwrap!()

{:noreply, Map.put(state, key, ref)}
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use a process instead such that, if someone does Application.stop(:exla) the process is shutdown as well as all plugins? 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding how to store things, I would rather do an ETS than a process for storing the custom call registry if persistent term is not what we want, to keep close to the same read concurrency.

Another possible alternative would be a GenServer that manages the persistent term; on terminate, it would clean up the persistent term state.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this manager alternative, we'd have non-process functions that read first and write if needed to the persistent term, and the only purpose for the GenServer would be to ensure cleanup upon termination.

PS: Upon writing this I went reading and found https://hexdocs.pm/elixir/1.12/Application.html#c:prep_stop/1 which would serve this purpose nicely.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prep_stop also works, but you would need to iterate over all persistent terms to find the keys relevant to us. ETS would be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's something like EXLA.CustomCalls.cleanup/0 we could call, then that function will already know about all of the keys that should be cleaned up.

I see no issue with using ETS however, as the speed difference here only matters at defn compile time and not runtime

end

# Replies with the plugin reference stored under `key` (`nil` when absent) and
# leaves the state unchanged.
@impl true
def handle_call({:lookup, key}, _from, state),
  do: {:reply, Map.get(state, key), state}

# Returns the value inside an `{:ok, ref}` tuple. For `{:error, reason}` it
# raises with `reason` interpolated into the message.
defp unwrap!({:ok, ref}), do: ref
defp unwrap!({:error, reason}), do: raise("#{reason}")
end
9 changes: 9 additions & 0 deletions exla/test/exla/plugin_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defmodule EXLA.PluginTest do
  use ExUnit.Case

  # `EXLA.Plugin.register/2` takes a key and a library path, so the describe
  # block is labelled with arity 2.
  describe "register/2" do
    test "registers a plugin" do
      assert :ok = EXLA.Plugin.register(:custom_plugin, "test/support/c/libcustom_plugin.so")
    end
  end
end
22 changes: 22 additions & 0 deletions exla/test/support/c/custom_plugin.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <cstdint>
#include <stddef.h>

// Signature every custom call exported by this plugin must have: output
// buffers, input buffers, and one array of dimension sizes per entry of `dims`.
using ExlaCustomCallFunction = void (*)(void *out[], const void *in[], int **dims);

// Pairs a custom call's name with the function implementing it.
struct ExlaPluginCustomCall {
  const char *name;
  ExlaCustomCallFunction func;
};

// Test custom call: writes `in[0][i] + 1` to `out[0][i]` for each of the first
// `dims[0][0]` elements.
//
//   out:  output buffers; `out[0]` is treated as an int64_t array.
//   in:   input buffers; `in[0]` is treated as an int64_t array.
//   dims: per-operand dimension sizes; `dims[0][0]` is used as the element
//         count of the first operand.
extern "C" void custom_increment(void *out[], const void *in[], int **dims) {
  const int64_t *operand = static_cast<const int64_t *>(in[0]);
  int64_t *out_buffer = static_cast<int64_t *>(out[0]);

  // The signature declares the dimension sizes as `int`, so read them as
  // `int`. Reinterpreting the array as int64_t would read past a 4-byte
  // element and produce a bogus count.
  const int64_t n = dims[0][0];

  for (int64_t i = 0; i < n; i++) {
    out_buffer[i] = operand[i] + 1;
  }
}
Binary file added exla/test/support/c/libcustom_plugin.so
Binary file not shown.
Loading