Skip to content

Commit

Permalink
Adding test for metering
Browse files Browse the repository at this point in the history
  • Loading branch information
Southporter committed Jun 11, 2024
1 parent 8ef7271 commit ed73d8c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
36 changes: 36 additions & 0 deletions src/tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ const core = @import("core.zig");
const Limits = core.Limits;
const MemoryInstance = core.MemoryInstance;

const metering = @import("metering.zig");

test "StackVM.Integration" {
const wasm_filepath = "zig-out/bin/mandelbrot.wasm";

Expand All @@ -27,6 +29,40 @@ test "StackVM.Integration" {
defer module_inst.destroy();
}

test "StackVM.Metering" {
const wasm_filepath = "zig-out/bin/fibonacci.wasm";

var allocator = std.testing.allocator;

var cwd = std.fs.cwd();
const wasm_data: []u8 = try cwd.readFileAlloc(allocator, wasm_filepath, 1024 * 1024 * 128);
defer allocator.free(wasm_data);

const module_def_opts = core.ModuleDefinitionOpts{
.debug_name = std.fs.path.basename(wasm_filepath),
};
var module_def = try core.createModuleDefinition(allocator, module_def_opts);
defer module_def.destroy();

try module_def.decode(wasm_data);

var module_inst = try core.createModuleInstance(.Stack, module_def, allocator);
defer module_inst.destroy();

try module_inst.instantiate(.{});

var returns = [1]core.Val{.{ .I64 = 5555 }};
var params = [1]core.Val{.{ .I32 = 10 }};

const handle = try module_inst.getFunctionHandle("run");
const res = module_inst.invoke(handle, &params, &returns, .{
.meter = 2,
});
try std.testing.expectError(metering.MeteringTrapError.TrapMeterExceeded, res);
try module_inst.resumeInvoke(&returns, .{ .meter = 10000 });
try std.testing.expectEqual(89, returns[0].I32);
}

test "MemoryInstance.init" {
{
const limits = Limits{
Expand Down
31 changes: 17 additions & 14 deletions src/vm_stack.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1581,13 +1581,12 @@ const InstructionFuncs = struct {
const root_stackvm: *StackVM = StackVM.fromVM(root_module_instance.vm);

if (metering.enabled) {
var state = root_stackvm.meter_state;
if (root_stackvm.meter_state.enabled) {
const meter = metering.reduce(state.meter, code[pc]);
state.meter = meter;
const meter = metering.reduce(root_stackvm.meter_state.meter, code[pc]);
root_stackvm.meter_state.meter = meter;
if (meter == 0) {
state.pc = pc;
state.opcode = code[pc].opcode;
root_stackvm.meter_state.pc = pc;
root_stackvm.meter_state.opcode = code[pc].opcode;
return metering.MeteringTrapError.TrapMeterExceeded;
}
}
Expand Down Expand Up @@ -5227,7 +5226,7 @@ pub const StackVM = struct {

const MeterState = if (metering.enabled) struct {
pc: u32 = 0,
opcode: Opcode,
opcode: Opcode = Opcode.Invalid,
meter: metering.Meter,
enabled: bool = false,

Expand Down Expand Up @@ -5313,8 +5312,11 @@ pub const StackVM = struct {
}
if (metering.enabled) {
if (opts.meter != metering.initial_meter) {
self.meter_state.enabled = true;
self.meter_state.meter = opts.meter;
self.meter_state = .{
.enabled = true,
.meter = opts.meter,
.opcode = Opcode.Invalid,
};
}
}

Expand All @@ -5340,27 +5342,28 @@ pub const StackVM = struct {
var self: *StackVM = fromVM(vm);

var pc: u32 = 0;
const opcode: Opcode = if (self.debug_state) |debug_state| blk: {
var opcode = Opcode.Invalid;
if (self.debug_state) |debug_state| {
std.debug.assert(debug_state.is_invoking);
pc = debug_state.pc;
for (debug_state.trapped_opcodes.items) |op| {
if (op.address == debug_state.pc) {
break :blk op.opcode;
opcode = op.opcode;
break;
}
}
unreachable; // Should never get into a state where a trapped opcode doesn't have an associated record

} else blk: {
} else {
if (metering.enabled) {
std.debug.assert(self.meter_state.enabled);
pc = self.meter_state.pc;
if (opts.meter != metering.initial_meter) {
self.meter_state.meter = opts.meter;
}
break :blk self.meter_state.opcode;
opcode = self.meter_state.opcode;
}
unreachable;
};
}

const op_func = InstructionFuncs.lookup(opcode);
try op_func(pc, module.module_def.code.instructions.items.ptr, &self.stack);
Expand Down

0 comments on commit ed73d8c

Please sign in to comment.