Skip to content

Commit

Permalink
Harden how ConstEval uses llvm-cpu and the runtime libraries. (iree-o…
Browse files Browse the repository at this point in the history
…rg#17075)

* Fixed iree-org#17070 by updating the
CMake options needed for ConstEval
* Replaced `IREE_CHECK_OK` usage with error handling
* Refactored `test/jit_globals.mlir`, adding coverage for llvm-cpu
(since that is actually running by default)
* I tried to keep all test cases in one file, but `--verify-diagnostics`
isn't compatible with that style of lit testing AFAICT
* This uncovered some bugs in iree-org#16321
/ missing support for i4 types
  • Loading branch information
ScottTodd authored Apr 22, 2024
1 parent d12291f commit f5660ee
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 93 deletions.
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,11 @@ option(IREE_HAL_EXECUTABLE_PLUGIN_EMBEDDED_ELF "Enables the embedded dynamic lib
option(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY "Enables the system dynamic library plugin mechanism for local HAL drivers" ${IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY_DEFAULT})

if(IREE_BUILD_COMPILER)
# The compiler requires the local task driver with the VMVX loader.
# The compiler minimally requires the local task driver with the default
# (embedded elf) executable loader. This is used by the ConstEval component,
# which can also be used with VMVX or other loaders/devices. See issue#17070.
set(IREE_HAL_DRIVER_LOCAL_TASK ON)
set(IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE ON)
set(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF ON)
endif()

message(STATUS "IREE HAL drivers:")
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/ConstEval/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ iree_compiler_cc_library(
deps = [
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode",
"//runtime/src/iree/base",
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/drivers/local_task/registration",
"//runtime/src/iree/modules/hal",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/ConstEval/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_cc_library(
DEPS
LLVMSupport
MLIRIR
iree::base
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Target::Bytecode
iree::hal
Expand Down
12 changes: 2 additions & 10 deletions compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,6 @@ struct JitGlobalsPass : public JitGlobalsBase<JitGlobalsPass> {
SupportedFeatures s;
Builder b(context);

// Exclude vmvx backend since there is no i4 support there causing
// the `eval_i4_tensor` test in `jit_globals.mlir` to fail.
// TODO(#16321): Enable on other backends once this has been tested
// outside llvm-cpu.
if (requestedTargetDevice == "llvm-cpu" && hasRequestedTargetDevice)
s.addScalarType(b.getIntegerType(4));
s.addScalarType(b.getIntegerType(8));
s.addScalarType(b.getIntegerType(16));
s.addScalarType(b.getIntegerType(32));
Expand All @@ -617,10 +611,6 @@ struct JitGlobalsPass : public JitGlobalsBase<JitGlobalsPass> {

s.addElementType(b.getIntegerType(1));

// TODO(#16321): Enable on other backends once this has been tested outside
// llvm-cpu.
if (requestedTargetDevice == "llvm-cpu" && hasRequestedTargetDevice)
s.addElementType(b.getIntegerType(4));
s.addElementType(b.getIntegerType(8));
s.addElementType(b.getIntegerType(16));
s.addElementType(b.getIntegerType(32));
Expand Down Expand Up @@ -655,6 +645,8 @@ struct JitGlobalsPass : public JitGlobalsBase<JitGlobalsPass> {

FunctionCall call(binary, jitFunction.argumentBindings.size(),
jitFunction.resultBindings.size());
if (failed(call.initialize(jitFunction.loc)))
return failure();

// Convert arguments.
for (ArgumentBinding &arg : jitFunction.argumentBindings) {
Expand Down
185 changes: 115 additions & 70 deletions compiler/src/iree/compiler/ConstEval/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/ConstEval/Runtime.h"

#include "iree/base/api.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/hal/drivers/local_task/registration/driver_module.h"
Expand All @@ -19,20 +20,15 @@ namespace mlir::iree_compiler::ConstEval {

namespace {

LogicalResult handleRuntimeError(Location loc, iree_status_t status) {
LogicalResult handleRuntimeError(Location loc, iree_status_t status,
bool freeStatus = true) {
if (iree_status_is_ok(status))
return success();
std::string message;
message.resize(512);
iree_host_size_t buffer_length;
if (!iree_status_format(status, message.size(), &message[0],
&buffer_length)) {
message.resize(buffer_length + 1);
iree_status_format(status, message.size(), &message[0], &buffer_length);
std::string statusString = iree::Status::ToString(status);
if (freeStatus) {
iree_status_ignore(status);
}
message.resize(buffer_length);
iree_status_ignore(status);
return emitError(loc) << "runtime error in consteval: " << message;
return emitError(loc) << "runtime error in consteval: " << statusString;
}

LogicalResult convertToElementType(Location loc, Type baseType,
Expand Down Expand Up @@ -149,13 +145,21 @@ void CompiledBinary::deinitialize() {

FunctionCall::FunctionCall(CompiledBinary &binary, iree_host_size_t argCapacity,
iree_host_size_t resultCapacity)
: binary(binary) {
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
argCapacity, iree_allocator_system(),
&inputs));
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
resultCapacity, iree_allocator_system(),
&outputs));
: binary(binary), argCapacity(argCapacity), resultCapacity(resultCapacity) {
}

LogicalResult FunctionCall::initialize(Location loc) {
iree_status_t status = iree_ok_status();
if (iree_status_is_ok(status)) {
status = iree_vm_list_create(iree_vm_make_undefined_type_def(), argCapacity,
iree_allocator_system(), &inputs);
}
if (iree_status_is_ok(status)) {
status =
iree_vm_list_create(iree_vm_make_undefined_type_def(), resultCapacity,
iree_allocator_system(), &outputs);
}
return handleRuntimeError(loc, status);
}

FailureOr<iree::vm::ref<iree_hal_buffer_t>>
Expand All @@ -174,28 +178,36 @@ FunctionCall::importSerializableAttr(
std::memset(&params, 0, sizeof(params));
params.type =
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
if (failed(handleRuntimeError(
loc, iree_hal_allocator_allocate_buffer(binary.getAllocator(), params,
storageSize, &buffer))))
return failure();

iree_status_t status = iree_ok_status();
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_allocate_buffer(binary.getAllocator(), params,
storageSize, &buffer);
}

iree_hal_buffer_mapping_t mapping;
if (failed(handleRuntimeError(
loc, iree_hal_buffer_map_range(
buffer.get(), IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_WRITE, /*byte_offset=*/0,
/*byte_length=*/storageSize, &mapping))))
return failure();
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_map_range(
buffer.get(), IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_WRITE, /*byte_offset=*/0,
/*byte_length=*/storageSize, &mapping);
}

// Copy.
LogicalResult copyResult = serializableAttr.serializeToBuffer(
loc, llvm::endianness::native,
ArrayRef<char>(reinterpret_cast<char *>(mapping.contents.data),
storageSize));
if (iree_status_is_ok(status)) {
LogicalResult copyResult = serializableAttr.serializeToBuffer(
loc, llvm::endianness::native,
ArrayRef<char>(reinterpret_cast<char *>(mapping.contents.data),
storageSize));
iree_status_ignore(iree_hal_buffer_unmap_range(&mapping));
if (failed(copyResult)) {
status =
iree_make_status(IREE_STATUS_INTERNAL, "serializeToBuffer failed");
}
}

if (failed(handleRuntimeError(loc, iree_hal_buffer_unmap_range(&mapping))) ||
failed(copyResult)) {
return failure();
if (!iree_status_is_ok(status)) {
iree_hal_buffer_release(buffer.get());
return handleRuntimeError(loc, status);
}

return buffer;
Expand Down Expand Up @@ -404,15 +416,19 @@ TypedAttr CompiledBinary::convertVariantToAttribute(Location loc,
// mapping is not available. Today with the CPU backends it's always
// possible but would not work with accelerators.
iree_hal_buffer_mapping_t mapping;
IREE_CHECK_OK(iree_hal_buffer_map_range(
buffer, IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, length, &mapping));
if (failed(handleRuntimeError(
loc, iree_hal_buffer_map_range(
buffer, IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, length, &mapping)))) {
return {};
}
MutableArrayRef<char> rawBufferArray(
reinterpret_cast<char *>(mapping.contents.data),
mapping.contents.data_length);
auto convertedAttr =
createAttributeFromRawData(loc, tensorType, rawBufferArray);
iree_hal_buffer_unmap_range(&mapping);
iree_status_ignore(iree_hal_buffer_unmap_range(&mapping));
return convertedAttr;
} else {
iree_string_view_t typeName =
Expand All @@ -427,37 +443,58 @@ TypedAttr CompiledBinary::convertVariantToAttribute(Location loc,
return {};
}

void CompiledBinary::initialize(void *data, size_t length) {
LogicalResult CompiledBinary::initialize(Location loc, void *data,
size_t length) {
Runtime &runtime = Runtime::getInstance();
// Keep the sticky initStatus alive then free in |runtime|'s destructor.
if (failed(
handleRuntimeError(loc, runtime.initStatus, /*freeStatus=*/false))) {
return failure();
}

iree_status_t status = iree_ok_status();

// Create driver and device.
iree_hal_driver_t *driver = nullptr;
IREE_CHECK_OK(iree_hal_driver_registry_try_create(
runtime.registry, iree_make_cstring_view("local-task"),
iree_allocator_system(), &driver));
IREE_CHECK_OK(iree_hal_driver_create_default_device(
driver, iree_allocator_system(), &device));
if (iree_status_is_ok(status)) {
status = iree_hal_driver_registry_try_create(
runtime.registry, iree_make_cstring_view("local-task"),
iree_allocator_system(), &driver);
}

if (iree_status_is_ok(status)) {
status = iree_hal_driver_create_default_device(
driver, iree_allocator_system(), &device);
}
iree_hal_driver_release(driver);

// Create hal module.
iree_hal_device_t *device_ptr = device.get();
IREE_CHECK_OK(iree_hal_module_create(
runtime.instance.get(), /*device_count=*/1, &device_ptr,
IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module));
if (iree_status_is_ok(status)) {
std::array<iree_hal_device_t *, 1> devices = {device.get()};
status = iree_hal_module_create(runtime.instance.get(), devices.size(),
devices.data(), IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module);
}

// Bytecode module.
IREE_CHECK_OK(iree_vm_bytecode_module_create(
runtime.instance.get(), iree_make_const_byte_span(data, length),
iree_allocator_null(), iree_allocator_system(), &main_module));

// Context.
std::array<iree_vm_module_t *, 2> modules = {
hal_module.get(),
main_module.get(),
};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
modules.data(), iree_allocator_system(), &context));
if (iree_status_is_ok(status)) {
status = iree_vm_bytecode_module_create(
runtime.instance.get(), iree_make_const_byte_span(data, length),
iree_allocator_null(), iree_allocator_system(), &main_module);
}

// Create context.
if (iree_status_is_ok(status)) {
std::array<iree_vm_module_t *, 2> modules = {
hal_module.get(),
main_module.get(),
};
status = iree_vm_context_create_with_modules(
runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
modules.data(), iree_allocator_system(), &context);
}

return handleRuntimeError(loc, status);
}

InMemoryCompiledBinary::~InMemoryCompiledBinary() { deinitialize(); }
Expand All @@ -472,20 +509,28 @@ InMemoryCompiledBinary::translateFromModule(mlir::ModuleOp moduleOp) {
return failure();
}
os.flush();
initialize(&binary[0], binary.length());
return success();
return initialize(moduleOp.getLoc(), &binary[0], binary.length());
}

Runtime::Runtime() {
IREE_CHECK_OK(
iree_hal_driver_registry_allocate(iree_allocator_system(), &registry));
IREE_CHECK_OK(iree_hal_local_task_driver_module_register(registry));
IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
iree_allocator_system(), &instance));
IREE_CHECK_OK(iree_hal_module_register_all_types(instance.get()));
if (iree_status_is_ok(initStatus)) {
initStatus =
iree_hal_driver_registry_allocate(iree_allocator_system(), &registry);
}
if (iree_status_is_ok(initStatus)) {
initStatus = iree_hal_local_task_driver_module_register(registry);
}
if (iree_status_is_ok(initStatus)) {
initStatus = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
iree_allocator_system(), &instance);
}
if (iree_status_is_ok(initStatus)) {
initStatus = iree_hal_module_register_all_types(instance.get());
}
}

Runtime::~Runtime() {
iree_status_free(initStatus);
instance.reset();
iree_hal_driver_registry_free(registry);
}
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/ConstEval/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CompiledBinary {

protected:
CompiledBinary();
void initialize(void *data, size_t length);
LogicalResult initialize(Location loc, void *data, size_t length);
// The base class does not clean up initialized state. This must be done
// explicitly by subclasses, ensuring that any backing images remain valid
// through the call to deinitialize().
Expand All @@ -55,6 +55,7 @@ class FunctionCall {
FunctionCall(CompiledBinary &binary, iree_host_size_t argCapacity,
iree_host_size_t resultCapacity);

LogicalResult initialize(Location loc);
LogicalResult addArgument(Location loc, Attribute attr);
LogicalResult invoke(Location loc, StringRef name);
LogicalResult getResultAsAttr(Location loc, size_t index, Type mlirType,
Expand All @@ -71,6 +72,8 @@ class FunctionCall {
IREE::Util::SerializableAttrInterface serializableAttr);

CompiledBinary binary;
iree_host_size_t argCapacity;
iree_host_size_t resultCapacity;
iree::vm::ref<iree_vm_list_t> inputs;
iree::vm::ref<iree_vm_list_t> outputs;
};
Expand All @@ -93,6 +96,7 @@ class Runtime {

iree_hal_driver_registry_t *registry = nullptr;
iree::vm::ref<iree_vm_instance_t> instance;
iree_status_t initStatus = iree_ok_status();

private:
Runtime();
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ iree_lit_test_suite(
"compile_regressions.mlir",
"failing.mlir",
"jit_globals.mlir",
"jit_globals_vmvx_errors.mlir",
"scalar_values.mlir",
],
include = ["*.mlir"],
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ iree_lit_test_suite(
"compile_regressions.mlir"
"failing.mlir"
"jit_globals.mlir"
"jit_globals_vmvx_errors.mlir"
"scalar_values.mlir"
TOOLS
FileCheck
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/ConstEval/test/failing.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --iree-consteval-jit-target-device=vmvx --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s
// XFAIL: *

// CHECK-LABEL: @eval_f64_scalar
Expand Down
Loading

0 comments on commit f5660ee

Please sign in to comment.