Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metering to the VM #49

Merged
merged 6 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
- name: Run unit tests
run: |
zig build test-unit
zig build -Dmeter=true test-unit

- name: Run wasm testsuite
run: |
Expand Down
15 changes: 15 additions & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ const ExeOpts = struct {
description: []const u8,
step_dependencies: ?[]*Build.Step = null,
should_emit_asm: bool = false,
options: *Build.Step.Options,
};

pub fn build(b: *Build) void {
const should_emit_asm = b.option(bool, "asm", "Emit asm for the bytebox binaries") orelse false;

const enable_metering = b.option(bool, "meter", "Enable metering") orelse false;

const options = b.addOptions();
options.addOption(bool, "enable_metering", enable_metering);

const target = b.standardTargetOptions(.{});
const optimize = b.standardOptimizeOption(.{});

Expand All @@ -37,6 +43,8 @@ pub fn build(b: *Build) void {
.imports = &[_]ModuleImport{stable_array_import},
});

bytebox_module.addOptions("config", options);

// exe.root_module.addImport(import.name, import.module);

const imports = [_]ModuleImport{
Expand All @@ -50,6 +58,7 @@ pub fn build(b: *Build) void {
.step_name = "run",
.description = "Run a wasm program",
.should_emit_asm = should_emit_asm,
.options = options,
});

var bench_steps = [_]*Build.Step{
Expand All @@ -63,6 +72,7 @@ pub fn build(b: *Build) void {
.step_name = "bench",
.description = "Run the benchmark suite",
.step_dependencies = &bench_steps,
.options = options,
});

const lib_bytebox: *Build.Step.Compile = b.addStaticLibrary(.{
Expand All @@ -72,6 +82,7 @@ pub fn build(b: *Build) void {
.optimize = optimize,
});
lib_bytebox.root_module.addImport(stable_array_import.name, stable_array_import.module);
lib_bytebox.root_module.addOptions("config", options);
lib_bytebox.installHeader(b.path("src/bytebox.h"), "bytebox.h");
b.installArtifact(lib_bytebox);

Expand All @@ -82,6 +93,7 @@ pub fn build(b: *Build) void {
.optimize = optimize,
});
unit_tests.root_module.addImport(stable_array_import.name, stable_array_import.module);
unit_tests.root_module.addOptions("config", options);
const run_unit_tests = b.addRunArtifact(unit_tests);
const unit_test_step = b.step("test-unit", "Run unit tests");
unit_test_step.dependOn(&run_unit_tests.step);
Expand All @@ -92,6 +104,7 @@ pub fn build(b: *Build) void {
.root_src = "test/wasm/main.zig",
.step_name = "test-wasm",
.description = "Run the wasm testsuite",
.options = options,
});

// wasi tests
Expand All @@ -109,6 +122,7 @@ pub fn build(b: *Build) void {
.root_src = "test/mem64/main.zig",
.step_name = "test-mem64",
.description = "Run the mem64 test",
.options = options,
});

// All tests
Expand All @@ -130,6 +144,7 @@ fn buildExeWithRunStep(b: *Build, target: Build.ResolvedTarget, optimize: std.bu
for (imports) |import| {
exe.root_module.addImport(import.name, import.module);
}
exe.root_module.addOptions("config", opts.options);

// exe.emit_asm = if (opts.should_emit_asm) .emit else .default;
b.installArtifact(exe);
Expand Down
2 changes: 1 addition & 1 deletion build.zig.zon
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.{
.name = "bytebox",
.version = "0.0.1",
.minimum_zig_version = "0.12.0",
.minimum_zig_version = "0.13.0",
.paths = .{
"src",
"test/mem64",
Expand Down
19 changes: 13 additions & 6 deletions src/instance.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ const AllocError = std.mem.Allocator.Error;

const builtin = @import("builtin");

const metering = @import("metering.zig");

const common = @import("common.zig");
const StableArray = common.StableArray;
const Logger = common.Logger;
Expand Down Expand Up @@ -46,6 +48,7 @@ pub const ExportError = error{

pub const TrapError = error{
TrapDebug,
TrapInvalidResume,
TrapUnreachable,
TrapIntegerDivisionByZero,
TrapIntegerOverflow,
Expand All @@ -57,7 +60,7 @@ pub const TrapError = error{
TrapOutOfBoundsTableAccess,
TrapStackExhausted,
TrapUnknown,
};
} || metering.MeteringTrapError;

pub const DebugTrace = struct {
pub const Mode = enum {
Expand Down Expand Up @@ -616,6 +619,10 @@ pub const ModuleInstantiateOpts = struct {

pub const InvokeOpts = struct {
trap_on_start: bool = false,
meter: metering.Meter = metering.initial_meter,
};
pub const ResumeInvokeOpts = struct {
meter: metering.Meter = metering.initial_meter,
};

pub const DebugTrapInstructionMode = enum {
Expand All @@ -629,7 +636,7 @@ pub const VM = struct {
const InstantiateFn = *const fn (vm: *VM, module: *ModuleInstance, opts: ModuleInstantiateOpts) anyerror!void;
const InvokeFn = *const fn (vm: *VM, module: *ModuleInstance, handle: FunctionHandle, params: [*]const Val, returns: [*]Val, opts: InvokeOpts) anyerror!void;
const InvokeWithIndexFn = *const fn (vm: *VM, module: *ModuleInstance, func_index: usize, params: [*]const Val, returns: [*]Val) anyerror!void;
const ResumeInvokeFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void;
const ResumeInvokeFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void;
const StepFn = *const fn (vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void;
const SetDebugTrapFn = *const fn (vm: *VM, module: *ModuleInstance, wasm_address: u32, mode: DebugTrapInstructionMode) anyerror!bool;
const FormatBacktraceFn = *const fn (vm: *VM, indent: u8, allocator: std.mem.Allocator) anyerror!std.ArrayList(u8);
Expand Down Expand Up @@ -699,8 +706,8 @@ pub const VM = struct {
try vm.invoke_with_index_fn(vm, module, func_index, params, returns);
}

pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
try vm.resume_invoke_fn(vm, module, returns);
pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
try vm.resume_invoke_fn(vm, module, returns, opts);
}

pub fn step(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
Expand Down Expand Up @@ -1186,8 +1193,8 @@ pub const ModuleInstance = struct {
}

/// Use to resume an invoked function after it returned error.DebugTrap
pub fn resumeInvoke(self: *ModuleInstance, returns: []Val) anyerror!void {
try self.vm.resumeInvoke(self, returns);
pub fn resumeInvoke(self: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
try self.vm.resumeInvoke(self, returns, opts);
}

pub fn step(self: *ModuleInstance, returns: []Val) anyerror!void {
Expand Down
20 changes: 20 additions & 0 deletions src/metering.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const config = @import("config");
const Instruction = @import("definition.zig").Instruction;

pub const enabled = config.enable_metering;

pub const Meter = if (enabled) usize else void;

pub const initial_meter = if (enabled) 0 else {};

pub const MeteringTrapError = if (enabled) error{TrapMeterExceeded} else error{};

pub fn reduce(fuel: Meter, instruction: Instruction) Meter {
if (fuel == 0) {
return fuel;
}
return switch (instruction.opcode) {
.Invalid, .Unreachable, .DebugTrap, .Noop, .Block, .Loop, .If, .IfNoElse, .Else, .End, .Branch, .Branch_If, .Branch_Table, .Drop => fuel,
else => fuel - 1,
};
}
45 changes: 45 additions & 0 deletions src/tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ const core = @import("core.zig");
const Limits = core.Limits;
const MemoryInstance = core.MemoryInstance;

const metering = @import("metering.zig");

test "StackVM.Integration" {
const wasm_filepath = "zig-out/bin/mandelbrot.wasm";

Expand All @@ -27,6 +29,49 @@ test "StackVM.Integration" {
defer module_inst.destroy();
}

test "StackVM.Metering" {
if (!metering.enabled) {
return;
}
const wasm_filepath = "zig-out/bin/fibonacci.wasm";

var allocator = std.testing.allocator;

var cwd = std.fs.cwd();
const wasm_data: []u8 = try cwd.readFileAlloc(allocator, wasm_filepath, 1024 * 1024 * 128);
defer allocator.free(wasm_data);

const module_def_opts = core.ModuleDefinitionOpts{
.debug_name = std.fs.path.basename(wasm_filepath),
};
var module_def = try core.createModuleDefinition(allocator, module_def_opts);
defer module_def.destroy();

try module_def.decode(wasm_data);

var module_inst = try core.createModuleInstance(.Stack, module_def, allocator);
defer module_inst.destroy();

try module_inst.instantiate(.{});

var returns = [1]core.Val{.{ .I64 = 5555 }};
var params = [1]core.Val{.{ .I32 = 10 }};

const handle = try module_inst.getFunctionHandle("run");
const res = module_inst.invoke(handle, &params, &returns, .{
.meter = 2,
});
try std.testing.expectError(metering.MeteringTrapError.TrapMeterExceeded, res);
try std.testing.expectEqual(5555, returns[0].I32);

const res2 = module_inst.resumeInvoke(&returns, .{ .meter = 5 });
try std.testing.expectError(metering.MeteringTrapError.TrapMeterExceeded, res2);
try std.testing.expectEqual(5555, returns[0].I32);

try module_inst.resumeInvoke(&returns, .{ .meter = 10000 });
try std.testing.expectEqual(89, returns[0].I32);
}

test "MemoryInstance.init" {
{
const limits = Limits{
Expand Down
4 changes: 3 additions & 1 deletion src/vm_register.zig
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const inst = @import("instance.zig");
const VM = inst.VM;
const ModuleInstance = inst.ModuleInstance;
const InvokeOpts = inst.InvokeOpts;
const ResumeInvokeOpts = inst.ResumeInvokeOpts;
const DebugTrapInstructionMode = inst.DebugTrapInstructionMode;
const ModuleInstantiateOpts = inst.ModuleInstantiateOpts;

Expand Down Expand Up @@ -1131,10 +1132,11 @@ pub const RegisterVM = struct {
return error.Unimplemented;
}

pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val) anyerror!void {
pub fn resumeInvoke(vm: *VM, module: *ModuleInstance, returns: []Val, opts: ResumeInvokeOpts) anyerror!void {
_ = vm;
_ = module;
_ = returns;
_ = opts;
return error.Unimplemented;
}

Expand Down
Loading
Loading