diff --git a/src/definition.zig b/src/definition.zig index d77d48e..fbc9a92 100644 --- a/src/definition.zig +++ b/src/definition.zig @@ -349,28 +349,27 @@ pub const Limits = struct { } }; -const BlockType = enum { +const BlockType = enum(u8) { Void, ValType, TypeIndex, }; -pub const BlockTypeValue = union(BlockType) { - Void: void, +pub const BlockTypeValue = extern union { ValType: ValType, TypeIndex: u32, - fn getBlocktypeParamTypes(blocktype: BlockTypeValue, module_def: *const ModuleDefinition) []const ValType { - switch (blocktype) { + fn getBlocktypeParamTypes(value: BlockTypeValue, block_type: BlockType, module_def: *const ModuleDefinition) []const ValType { + switch (block_type) { else => return &BlockTypeStatics.empty, - .TypeIndex => |index| return module_def.types.items[index].getParams(), + .TypeIndex => return module_def.types.items[value.TypeIndex].getParams(), } } - fn getBlocktypeReturnTypes(blocktype: BlockTypeValue, module_def: *const ModuleDefinition) []const ValType { - switch (blocktype) { + fn getBlocktypeReturnTypes(value: BlockTypeValue, block_type: BlockType, module_def: *const ModuleDefinition) []const ValType { + switch (block_type) { .Void => return &BlockTypeStatics.empty, - .ValType => |v| return switch (v) { + .ValType => return switch (value.ValType) { .I32 => &BlockTypeStatics.valtype_i32, .I64 => &BlockTypeStatics.valtype_i64, .F32 => &BlockTypeStatics.valtype_f32, @@ -379,21 +378,8 @@ pub const BlockTypeValue = union(BlockType) { .FuncRef => &BlockTypeStatics.reftype_funcref, .ExternRef => &BlockTypeStatics.reftype_externref, }, - .TypeIndex => |index| return module_def.types.items[index].getReturns(), - } - } - - fn toU64(value: BlockTypeValue) u64 { - comptime { - std.debug.assert(@sizeOf(BlockTypeValue) == @sizeOf(u64)); + .TypeIndex => return module_def.types.items[value.TypeIndex].getReturns(), } - const value_ptr: *const BlockTypeValue = &value; - return @as(*const u64, @ptrCast(value_ptr)).*; - } - - fn fromU64(value: u64) BlockTypeValue { - const value_ptr: *const u64 = &value; - return @as(*const BlockTypeValue, @ptrCast(value_ptr)).*; } }; @@ -595,10 +581,10 @@ pub const FunctionTypeDefinition = struct { }; pub const FunctionDefinition = struct { - type_index: u32, - instructions_begin: u32, - instructions_end: u32, - continuation: u32, + type_index: usize, + instructions_begin: usize, + instructions_end: usize, + continuation: usize, locals: std.ArrayList(ValType), // TODO use a slice of a large contiguous array instead pub fn instructions(func: FunctionDefinition, module_def: ModuleDefinition) []Instruction { @@ -793,59 +779,71 @@ const MemArg = struct { } }; -pub const MemoryOffsetAndLaneImmediates = struct { +pub const MemoryOffsetAndLaneImmediates = extern struct { offset: u64, laneidx: u8, }; -pub const CallIndirectImmediates = struct { +pub const CallIndirectImmediates = extern struct { type_index: u32, table_index: u32, }; -pub const BranchTableImmediates = struct { - label_ids: std.ArrayList(u32), // TODO optimize to make less allocations +pub const BranchTableImmediates = extern struct { + label_ids_begin: u32, + label_ids_end: u32, fallback_id: u32, + + pub fn getLabelIds(self: BranchTableImmediates, module: ModuleDefinition) []const u32 { + return module.code.branch_table_ids.items[self.label_ids_begin..self.label_ids_end]; + } }; -pub const TablePairImmediates = struct { +pub const TablePairImmediates = extern struct { index_x: u32, index_y: u32, }; -pub const BlockImmediates = struct { - blocktype: BlockTypeValue, +pub const BlockImmediates = extern struct { + block_type: BlockType, + block_value: BlockTypeValue, num_returns: u32, continuation: u32, }; -pub const IfImmediates = struct { - blocktype: BlockTypeValue, +pub const IfImmediates = extern struct { + block_type: BlockType, + block_value: BlockTypeValue, num_returns: u32, else_continuation: u32, end_continuation: u32, }; -const InstructionImmediatesTypes = enum(u8) { - Void, - ValType, - ValueI32, - ValueF32, - ValueI64, - ValueF64, - ValueVec, - Index, - LabelId, - MemoryOffset, - MemoryOffsetAndLane, - Block, - CallIndirect, - TablePair, - If, - VecShuffle16, +// const InstructionImmediatesTypes = enum(u8) { +// Void, +// ValType, +// ValueI32, +// ValueF32, +// ValueI64, +// ValueF64, +// ValueVec, +// Index, +// LabelId, +// MemoryOffset, +// MemoryOffsetAndLane, +// Block, +// CallIndirect, +// TablePair, +// If, +// VecShuffle16, +// }; + +pub const AlignedBytes = struct { + bytes: []align(1) const u8, + alignment: usize, }; -pub const InstructionImmediates = union(InstructionImmediatesTypes) { +pub const InstructionImmediates = extern union { Void: void, ValType: ValType, ValueI32: i32, @@ -871,13 +869,15 @@ pub const Instruction = struct { fn decode(reader: anytype, module: *ModuleDefinition) !Instruction { const Helpers = struct { fn decodeBlockType(_reader: anytype, _module: *const ModuleDefinition) !InstructionImmediates { - var blocktype: BlockTypeValue = undefined; + var block_type: BlockType = undefined; + var block_value: BlockTypeValue = undefined; const blocktype_raw = try _reader.readByte(); const valtype_or_err = ValType.bytecodeToValtype(blocktype_raw); if (std.meta.isError(valtype_or_err)) { if (blocktype_raw == k_block_type_void_sentinel_byte) { - blocktype = BlockTypeValue{ .Void = {} }; + block_type = .Void; + block_value = BlockTypeValue{ .TypeIndex = 0 }; } else { _reader.context.pos -= 1; // move the stream backwards 1 byte to reconstruct the integer const index_33bit = try common.decodeLEB128(i33, _reader); @@ -886,21 +886,24 @@ pub const Instruction = struct { } const index: u32 = @as(u32, @intCast(index_33bit)); if (index < _module.types.items.len) { - blocktype = BlockTypeValue{ .TypeIndex = index }; + block_type = .TypeIndex; + block_value = BlockTypeValue{ .TypeIndex = index }; } else { return error.ValidationUnknownBlockTypeIndex; } } } else { const valtype: ValType = valtype_or_err catch unreachable; - blocktype = BlockTypeValue{ .ValType = valtype }; + block_type = .ValType; + block_value = BlockTypeValue{ .ValType = valtype }; } - const num_returns: u32 = @as(u32, @intCast(blocktype.getBlocktypeReturnTypes(_module).len)); + const num_returns: u32 = @as(u32, @intCast(block_value.getBlocktypeReturnTypes(block_type, _module).len)); return InstructionImmediates{ .Block = BlockImmediates{ - .blocktype = blocktype, + .block_type = block_type, + .block_value = block_value, .num_returns = num_returns, .continuation = std.math.maxInt(u32), // will be set later in the code decode }, @@ -984,7 +987,8 @@ pub const Instruction = struct { const block_immediates: InstructionImmediates = try Helpers.decodeBlockType(reader, module); immediate = InstructionImmediates{ .If = IfImmediates{ - .blocktype = block_immediates.Block.blocktype, + .block_type = block_immediates.Block.block_type, + .block_value = block_immediates.Block.block_value, .num_returns = block_immediates.Block.num_returns, .else_continuation = block_immediates.Block.continuation, .end_continuation = block_immediates.Block.continuation, @@ -1002,6 +1006,7 @@ pub const Instruction = struct { const table_length = try common.decodeLEB128(u32, reader); var label_ids = std.ArrayList(u32).init(module.allocator); + defer label_ids.deinit(); try label_ids.ensureTotalCapacity(table_length); var index: u32 = 0; @@ -1011,26 +1016,33 @@ pub const Instruction = struct { } const fallback_id = try common.decodeLEB128(u32, reader); - var branch_table = BranchTableImmediates{ - .label_ids = label_ids, - .fallback_id = fallback_id, - }; - + // check to see if there are any existing tables we can reuse + var needs_immediate: bool = true; for (module.code.branch_table.items, 0..) |*item, i| { - if (item.fallback_id == branch_table.fallback_id) { - if (std.mem.eql(u32, item.label_ids.items, branch_table.label_ids.items)) { + if (item.fallback_id == fallback_id) { + const item_label_ids: []const u32 = item.getLabelIds(module.*); + if (std.mem.eql(u32, item_label_ids, label_ids.items)) { immediate = InstructionImmediates{ .Index = @as(u32, @intCast(i)) }; + needs_immediate = false; break; } } } - if (std.meta.activeTag(immediate) == .Void) { + if (needs_immediate) { immediate = InstructionImmediates{ .Index = @as(u32, @intCast(module.code.branch_table.items.len)) }; + + const label_ids_begin: u32 = @intCast(module.code.branch_table_ids.items.len); + try module.code.branch_table_ids.appendSlice(label_ids.items); + const label_ids_end: u32 = @intCast(module.code.branch_table_ids.items.len); + + const branch_table = BranchTableImmediates{ + .label_ids_begin = label_ids_begin, + .label_ids_end = label_ids_end, + .fallback_id = fallback_id, + }; + try module.code.branch_table.append(branch_table); - } else { - // don't need this anymore since we're reusing the existing one - branch_table.label_ids.deinit(); } }, .Call => { @@ -1308,7 +1320,7 @@ const CustomSection = struct { pub const NameCustomSection = struct { const NameAssoc = struct { name: []const u8, - func_index: u32, + func_index: usize, fn cmp(_: void, a: NameAssoc, b: NameAssoc) bool { return a.func_index < b.func_index; @@ -1410,7 +1422,7 @@ pub const NameCustomSection = struct { return self.module_name; } - pub fn findFunctionName(self: *const NameCustomSection, func_index: u32) []const u8 { + pub fn findFunctionName(self: *const NameCustomSection, func_index: usize) []const u8 { if (func_index < self.function_names.items.len) { if (self.function_names.items[func_index].func_index == func_index) { return self.function_names.items[func_index].name; @@ -1551,14 +1563,19 @@ const ModuleValidator = struct { } fn enterBlock(validator: *ModuleValidator, module_: *const ModuleDefinition, instruction_: Instruction) !void { - const blocktype: BlockTypeValue = switch (instruction_.immediate) { - .Block => |v| v.blocktype, - .If => |v| v.blocktype, + const block_type: BlockType = switch (instruction_.opcode) { + .Block, .Loop => instruction_.immediate.Block.block_type, + .If => instruction_.immediate.If.block_type, + else => unreachable, + }; + const block_value: BlockTypeValue = switch (instruction_.opcode) { + .Block, .Loop => instruction_.immediate.Block.block_value, + .If => instruction_.immediate.If.block_value, else => unreachable, }; - const start_types: []const ValType = blocktype.getBlocktypeParamTypes(module_); - const end_types: []const ValType = blocktype.getBlocktypeReturnTypes(module_); + const start_types: []const ValType = block_value.getBlocktypeParamTypes(block_type, module_); + const end_types: []const ValType = block_value.getBlocktypeReturnTypes(block_type, module_); try popReturnTypes(validator, start_types); try validator.pushControl(instruction_.opcode, start_types, end_types); @@ -1700,7 +1717,7 @@ const ModuleValidator = struct { frame.is_unreachable = true; } - fn popPushFuncTypes(validator: *ModuleValidator, type_index: u32, module_: *const ModuleDefinition) !void { + fn popPushFuncTypes(validator: *ModuleValidator, type_index: usize, module_: *const ModuleDefinition) !void { const func_type: *const FunctionTypeDefinition = &module_.types.items[type_index]; try popReturnTypes(validator, func_type.getParams()); @@ -1769,12 +1786,13 @@ const ModuleValidator = struct { }, .Branch_Table => { const immediates: *const BranchTableImmediates = &module.code.branch_table.items[instruction.immediate.Index]; + const label_ids: []const u32 = immediates.getLabelIds(module.*); const fallback_block_return_types: []const ValType = try Helpers.getControlTypes(self, immediates.fallback_id); try self.popType(.I32); - for (immediates.label_ids.items) |control_index| { + for (label_ids) |control_index| { const block_return_types: []const ValType = try Helpers.getControlTypes(self, control_index); if (fallback_block_return_types.len != block_return_types.len) { @@ -1811,7 +1829,7 @@ const ModuleValidator = struct { return error.ValidationUnknownFunction; } - const type_index: u32 = module.getFuncTypeIndex(func_index); + const type_index: usize = module.getFuncTypeIndex(func_index); try Helpers.popPushFuncTypes(self, type_index, module); }, .Call_Indirect => { @@ -2639,6 +2657,7 @@ pub const ModuleDefinition = struct { // Instruction.immediate indexes these arrays depending on the opcode branch_table: std.ArrayList(BranchTableImmediates), + branch_table_ids: std.ArrayList(u32), }; const Imports = struct { @@ -2687,6 +2706,7 @@ pub const ModuleDefinition = struct { .instructions = std.ArrayList(Instruction).init(allocator), .wasm_address_to_instruction_index = std.AutoHashMap(u32, u32).init(allocator), .branch_table = std.ArrayList(BranchTableImmediates).init(allocator), + .branch_table_ids = std.ArrayList(u32).init(allocator), }, .types = std.ArrayList(FunctionTypeDefinition).init(allocator), .imports = Imports{ @@ -3260,7 +3280,7 @@ pub const ModuleDefinition = struct { func_def.instructions_begin = @intCast(instructions.items.len); try block_stack.append(BlockData{ - .begin_index = func_def.instructions_begin, + .begin_index = @intCast(func_def.instructions_begin), .opcode = .Block, }); @@ -3302,11 +3322,11 @@ pub const ModuleDefinition = struct { if (block.opcode == .Loop) { block_instruction.immediate.Block.continuation = block.begin_index; } else { - switch (block_instruction.immediate) { - .Block => |*v| v.continuation = instruction_index, - .If => |*v| { - v.end_continuation = instruction_index; - v.else_continuation = instruction_index; + switch (block_instruction.opcode) { + .Block => block_instruction.immediate.Block.continuation = instruction_index, + .If => { + block_instruction.immediate.If.end_continuation = instruction_index; + block_instruction.immediate.If.else_continuation = instruction_index; }, else => unreachable, } @@ -3395,10 +3415,8 @@ pub const ModuleDefinition = struct { pub fn destroy(self: *ModuleDefinition) void { self.code.instructions.deinit(); self.code.wasm_address_to_instruction_index.deinit(); - for (self.code.branch_table.items) |*item| { - item.label_ids.deinit(); - } self.code.branch_table.deinit(); + self.code.branch_table_ids.deinit(); for (self.imports.functions.items) |*item| { self.allocator.free(item.names.module_name); @@ -3641,7 +3659,7 @@ pub const ModuleDefinition = struct { } } - fn getFuncTypeIndex(self: *const ModuleDefinition, func_index: usize) u32 { + fn getFuncTypeIndex(self: *const ModuleDefinition, func_index: usize) usize { if (func_index < self.imports.functions.items.len) { const func_def: *const FunctionImportDefinition = &self.imports.functions.items[func_index]; return func_def.type_index; diff --git a/src/instance.zig b/src/instance.zig index 4186fda..0b7e2f3 100644 --- a/src/instance.zig +++ b/src/instance.zig @@ -107,9 +107,9 @@ pub const DebugTrace = struct { } } - pub fn traceFunction(module_instance: *const ModuleInstance, indent: u32, func_index: u32) void { + pub fn traceFunction(module_instance: *const ModuleInstance, indent: u32, func_index: usize) void { if (shouldTraceFunctions()) { - const func_name_index: u32 = func_index + @as(u32, @intCast(module_instance.module_def.imports.functions.items.len)); + const func_name_index: usize = func_index + module_instance.module_def.imports.functions.items.len; const name_section: *const NameCustomSection = &module_instance.module_def.name_section; const module_name = name_section.getModuleName(); diff --git a/src/vm_stack.zig b/src/vm_stack.zig index ad1715b..ef2ca53 100644 --- a/src/vm_stack.zig +++ b/src/vm_stack.zig @@ -84,9 +84,9 @@ const DebugTraceStackVM = struct { }; const FunctionInstance = struct { - type_def_index: u32, - def_index: u32, - instructions_begin: u32, + type_def_index: usize, + def_index: usize, + instructions_begin: usize, local_types: std.ArrayList(ValType), }; @@ -1061,7 +1061,7 @@ const InstructionFuncs = struct { return FuncCallData{ .code = module_instance.module_def.code.instructions.items.ptr, - .continuation = func.instructions_begin, + .continuation = @intCast(func.instructions_begin), }; } @@ -1721,8 +1721,12 @@ const InstructionFuncs = struct { fn op_Branch_Table(pc: u32, code: [*]const Instruction, stack: *Stack) anyerror!void { try debugPreamble("Branch_Table", pc, code, stack); - const immediates: *const BranchTableImmediates = &stack.topFrame().module_instance.module_def.code.branch_table.items[code[pc].immediate.Index]; - const table: []const u32 = immediates.label_ids.items; + const module_instance: *const ModuleInstance = stack.topFrame().module_instance; + const all_branch_table_immediates: []const BranchTableImmediates = stack.topFrame().module_instance.module_def.code.branch_table.items; + const immediate_index = code[pc].immediate.Index; + + const immediates: BranchTableImmediates = all_branch_table_immediates[immediate_index]; + const table: []const u32 = immediates.getLabelIds(module_instance.module_def.*); const label_index = stack.popI32(); const label_id: u32 = if (label_index >= 0 and label_index < table.len) table[@as(usize, @intCast(label_index))] else immediates.fallback_id; @@ -5399,7 +5403,7 @@ pub const StackVM = struct { const name_section: *const NameCustomSection = &frame.module_instance.module_def.name_section; const module_name = name_section.getModuleName(); - const func_name_index: u32 = frame.func.def_index + @as(u32, @intCast(frame.module_instance.module_def.imports.functions.items.len)); + const func_name_index: usize = frame.func.def_index + frame.module_instance.module_def.imports.functions.items.len; const function_name = name_section.findFunctionName(func_name_index); try writer.print("{}: {s}!{s}\n", .{ reverse_index, module_name, function_name }); @@ -5436,11 +5440,11 @@ pub const StackVM = struct { } try self.stack.pushFrame(&func, module, param_types, func.local_types.items, func_type.calcNumReturns()); - try self.stack.pushLabel(@as(u32, @intCast(return_types.len)), func_def.continuation); + try self.stack.pushLabel(@as(u32, @intCast(return_types.len)), @intCast(func_def.continuation)); DebugTrace.traceFunction(module, self.stack.num_frames, func.def_index); - try InstructionFuncs.run(func.instructions_begin, module.module_def.code.instructions.items.ptr, &self.stack); + try InstructionFuncs.run(@intCast(func.instructions_begin), module.module_def.code.instructions.items.ptr, &self.stack); if (returns_slice.len > 0) { var index: i32 = @as(i32, @intCast(returns_slice.len - 1));