From 157aabf5ef0d386aca58a2af9b84a0ff5fc92474 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Apr 2025 02:01:27 +0800 Subject: [PATCH 01/40] try lms --- benchmarks/wasm/staged/push-drop.wat | 7 ++ src/main/scala/wasm/StagedMiniWasm.scala | 98 ++++++++++++++++++++ src/test/scala/genwasym/TestStagedEval.scala | 22 +++++ 3 files changed, 127 insertions(+) create mode 100644 benchmarks/wasm/staged/push-drop.wat create mode 100644 src/main/scala/wasm/StagedMiniWasm.scala create mode 100644 src/test/scala/genwasym/TestStagedEval.scala diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat new file mode 100644 index 00000000..eec1f816 --- /dev/null +++ b/benchmarks/wasm/staged/push-drop.wat @@ -0,0 +1,7 @@ +(module $push-drop + (func $real_main (type 1) (result i32) + i32.const 2 + i32.const 2 + drop + drop) + (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala new file mode 100644 index 00000000..26691b0e --- /dev/null +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -0,0 +1,98 @@ +package gensym.wasm.miniwasm + +import scala.collection.mutable.ArrayBuffer + +import lms.core.stub.Adapter +import lms.core.virtualize +import lms.core.stub.Base +import lms.core.Backend.{Block => LMSBlock} + +import gensym.wasm.ast._ +import gensym.wasm.ast.{Const => ConstInstr} + +case class StagedEvaluator(module: ModuleInstance) extends Base { + // reset and initialize the internal state of Adapter + Adapter.resetState + Adapter.g = Adapter.mkGraphBuilder + + type Stack = Rep[List[Value]] + type Cont[A] = Stack => Rep[A] + type Trail[A] = List[Cont[A]] + + // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program + def eval[Ans](insts: List[Instr], + stack: Stack, + frame: Rep[Frame], + kont: Cont[Ans], + trail: Trail[Ans]): Rep[Ans] = { + if (insts.isEmpty) return kont(stack) + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => eval(rest, stack.tail, frame, kont, trail) + // Why this cons operation compiled? does anything could be casted to Rep? + case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case _ => "todo-op".reflectWith() + } + } + + def evalTop[Ans](kont: Cont[Ans], main: Option[String]): Rep[Ans] = { + val funBody: FuncBodyDef = main match { + case Some(func_name) => + module.defs.flatMap({ + case Export(`func_name`, ExportFunc(fid)) => + println(s"Entering function $main") + module.funcs(fid) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) + case _ => throw new Exception("Entry function has no concrete body") + } + case _ => None + }).head + case None => + val startIds = module.defs.flatMap { + case Start(id) => Some(id) + case _ => None + } + val startId = startIds.headOption.getOrElse { throw new Exception("No start function") } + module.funcs(startId) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => body + case _ => + throw new Exception("Entry function has no concrete body") + } + } + val (instrs, localSize) = (funBody.body, funBody.locals.size) + val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) + eval(instrs, List(), frame, kont, List(kont)) + } + + def evalTop(main: Option[String]): Rep[Unit] = { + val haltK: Cont[Unit] = stack => { + "no-op".reflectWith() + } + evalTop(haltK, main) + } + + def codegen(main: Option[String]): LMSBlock = { + Adapter.g.reify( { Unwrap(evalTop(main)) } ) + } + + // The stack should be allocated on the stack to get optimal performance + implicit class StackOps(stack: Stack) { + def tail(): Stack = { + "value-stack-tail".reflectWith(stack) + } + + def ::[A](v: Rep[A]): Stack = { + "value-stack-cons".reflectWith(v, stack) + } + } + + // directly specify the translated operation + implicit class StringOps(op: String) { + def reflectWith[T: Manifest](rs: Rep[_]*): Rep[T] = { + val result = rs.map(Unwrap) + Predef.println(s"reflectWith: $op, $result") + val result1 = Adapter.g.reflect(op, result:_*) + Wrap[T](result1) + } + } +} diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala new file mode 100644 index 00000000..39e5c198 --- /dev/null +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -0,0 +1,22 @@ +package gensym.wasm + +import org.scalatest.FunSuite + +import lms.core.stub.Adapter + +import gensym.wasm.parser._ +import gensym.wasm.miniwasm._ + +class TestStagedEval extends FunSuite { + def testFile(filename: String, main: Option[String] = None) = { + val module = Parser.parseFile(filename) + val partialEvaluator = StagedEvaluator(ModuleInstance(module)) + val block = partialEvaluator.codegen(main) + println(Adapter.g) + println(block) + } + + test("push-drop") { + testFile("./benchmarks/wasm/staged/push-drop.wat") + } +} From a11a7ff561175561786aa420768a68b8e629e07b Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 25 Apr 2025 23:44:05 +0800 Subject: [PATCH 02/40] compose all parts --- src/main/scala/wasm/StagedMiniWasm.scala | 126 +++++++++++++------ src/test/scala/genwasym/TestStagedEval.scala | 8 +- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 26691b0e..f7416199 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -4,38 +4,47 @@ import scala.collection.mutable.ArrayBuffer import lms.core.stub.Adapter import lms.core.virtualize -import lms.core.stub.Base +import lms.macros.SourceContext +import lms.core.stub.{Base, ScalaGenBase} +import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ import gensym.wasm.ast.{Const => ConstInstr} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} -case class StagedEvaluator(module: ModuleInstance) extends Base { +@virtualize +trait StagedWasmEvaluator extends SAIOps { + def module: ModuleInstance + // NOTE: we don't need the following statements anymore, but where are they initialized? // reset and initialize the internal state of Adapter - Adapter.resetState - Adapter.g = Adapter.mkGraphBuilder + // Adapter.resetState + // Adapter.g = Adapter.mkGraphBuilder - type Stack = Rep[List[Value]] - type Cont[A] = Stack => Rep[A] + trait Stack + type Cont[A] = Rep[Stack => A] type Trail[A] = List[Cont[A]] // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program - def eval[Ans](insts: List[Instr], - stack: Stack, - frame: Rep[Frame], - kont: Cont[Ans], - trail: Trail[Ans]): Rep[Ans] = { - if (insts.isEmpty) return kont(stack) - val (inst, rest) = (insts.head, insts.tail) - inst match { - case Drop => eval(rest, stack.tail, frame, kont, trail) - // Why this cons operation compiled? does anything could be casted to Rep? - case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) - case _ => "todo-op".reflectWith() - } + def eval(insts: List[Instr], + stack: Rep[Stack], + frame: Rep[Frame], + kont: Cont[Unit], + trail: Trail[Unit]): Rep[Unit] = { + if (insts.isEmpty) return kont(stack) + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => eval(rest, stack.tail, frame, kont, trail) + case ConstInstr(num) => eval(rest, (num: Rep[Num]) :: stack, frame, kont, trail) + // case LocalGet(i) => + // eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case _ => + val noOp = "todo-op".reflectCtrlWith() + eval(rest, noOp :: stack, frame, kont, trail) + } } - def evalTop[Ans](kont: Cont[Ans], main: Option[String]): Rep[Ans] = { + def evalTop(kont: Cont[Unit], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { case Some(func_name) => module.defs.flatMap({ @@ -47,7 +56,7 @@ case class StagedEvaluator(module: ModuleInstance) extends Base { } case _ => None }).head - case None => + case None => val startIds = module.defs.flatMap { case Start(id) => Some(id) case _ => None @@ -61,38 +70,75 @@ case class StagedEvaluator(module: ModuleInstance) extends Base { } val (instrs, localSize) = (funBody.body, funBody.locals.size) val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) - eval(instrs, List(), frame, kont, List(kont)) + eval(instrs, emptyStack, unit(frame), kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { - val haltK: Cont[Unit] = stack => { - "no-op".reflectWith() + val haltK: Rep[Stack] => Rep[Unit] = stack => { + "no-op".reflectCtrlWith() } - evalTop(haltK, main) + evalTop(fun(haltK), main) } - def codegen(main: Option[String]): LMSBlock = { - Adapter.g.reify( { Unwrap(evalTop(main)) } ) + def emptyStack: Rep[Stack] = { + "empty-stack".reflectWith() } - // The stack should be allocated on the stack to get optimal performance - implicit class StackOps(stack: Stack) { - def tail(): Stack = { - "value-stack-tail".reflectWith(stack) + // TODO: The stack should be allocated on the stack to get optimal performance + implicit class StackOps(stack: Rep[Stack]) { + def tail(): Rep[Stack] = { + "stack-tail".reflectCtrlWith(stack) + } + + def ::[A](v: Rep[A]): Rep[Stack] = { + "stack-cons".reflectCtrlWith(v, stack) } + } +} +trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def traverse(n: Node): Unit = n match { + case _ => super.traverse(n) + } - def ::[A](v: Rep[A]): Stack = { - "value-stack-cons".reflectWith(v, stack) + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "stack-cons", List(v, stack), _) => + shallow(stack); emit(".push("); shallow(v); emit(")") + case Node(_, "stack-tail", List(stack), _) => + shallow(stack); emit(".pop()") + case Node(_, "empty-stack", _, _) => + emit("new Stack()") + case _ => super.shallow(n) + } +} +trait WasmCompilerDriver[A, B] + extends SAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmScalaGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Stack")) "Stack" + else super.remap(m) } } - // directly specify the translated operation - implicit class StringOps(op: String) { - def reflectWith[T: Manifest](rs: Rep[_]*): Rep[T] = { - val result = rs.map(Unwrap) - Predef.println(s"reflectWith: $op, $result") - val result1 = Adapter.g.reflect(op, result:_*) - Wrap[T](result1) + override val prelude = + """ + object Prelude { + } + import Prelude._ + """ +} + +object PartialEvaluator { + def apply(moduleInst: ModuleInstance, main: Option[String]): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main) + } } + code.code } } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 39e5c198..cc7197f0 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -9,11 +9,9 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFile(filename: String, main: Option[String] = None) = { - val module = Parser.parseFile(filename) - val partialEvaluator = StagedEvaluator(ModuleInstance(module)) - val block = partialEvaluator.codegen(main) - println(Adapter.g) - println(block) + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = PartialEvaluator(moduleInst, None) + println(code) } test("push-drop") { From 9408c85b848ea91b74610d4b8b9f12558abd6bec Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 01:15:02 +0800 Subject: [PATCH 03/40] Frame should be opaque --- benchmarks/wasm/staged/push-drop.wat | 3 +++ src/main/scala/wasm/StagedMiniWasm.scala | 26 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index eec1f816..19a3897e 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,7 +1,10 @@ (module $push-drop (func $real_main (type 1) (result i32) + (local i32 i32) i32.const 2 i32.const 2 + local.get 0 + local.get 1 drop drop) (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f7416199..d1a3c095 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -25,6 +25,8 @@ trait StagedWasmEvaluator extends SAIOps { type Cont[A] = Rep[Stack => A] type Trail[A] = List[Cont[A]] + trait Frame + // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program def eval(insts: List[Instr], stack: Rep[Stack], @@ -35,9 +37,9 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case ConstInstr(num) => eval(rest, (num: Rep[Num]) :: stack, frame, kont, trail) - // case LocalGet(i) => - // eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case LocalGet(i) => + eval(rest, frame.locals(i) :: stack, frame, kont, trail) case _ => val noOp = "todo-op".reflectCtrlWith() eval(rest, noOp :: stack, frame, kont, trail) @@ -69,8 +71,8 @@ trait StagedWasmEvaluator extends SAIOps { } } val (instrs, localSize) = (funBody.body, funBody.locals.size) - val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) - eval(instrs, emptyStack, unit(frame), kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + val frame = frameOf(localSize) + eval(instrs, emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { @@ -80,6 +82,8 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(fun(haltK), main) } + + // stack creation and operations def emptyStack: Rep[Stack] = { "empty-stack".reflectWith() } @@ -94,6 +98,18 @@ trait StagedWasmEvaluator extends SAIOps { "stack-cons".reflectCtrlWith(v, stack) } } + + // frame creation and operations + def frameOf(size: Int): Rep[Frame] = { + "frame-of".reflectWith(size) + } + + implicit class FrameOps(frame: Rep[Frame]) { + + def locals(i: Int): Rep[Num] = { + "frame-locals".reflectCtrlWith(frame, i) + } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { From 60c782b36b90d3216c9f76068a459737f1d86a8f Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 20:20:45 +0800 Subject: [PATCH 04/40] function call --- benchmarks/wasm/staged/push-drop.wat | 9 ++- src/main/scala/wasm/StagedMiniWasm.scala | 96 ++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 19a3897e..2c630da8 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,10 +1,15 @@ (module $push-drop - (func $real_main (type 1) (result i32) + (func (;0;) (type 1) (result i32) (local i32 i32) i32.const 2 i32.const 2 local.get 0 local.get 1 drop - drop) + drop + (call 1)) + (func (;1;) (type 1) (param i32 i32) (result i32) + (local i32 i32) + local.get 0 + local.get 1) (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index d1a3c095..f968b76c 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -1,6 +1,6 @@ package gensym.wasm.miniwasm -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import lms.core.stub.Adapter import lms.core.virtualize @@ -22,16 +22,19 @@ trait StagedWasmEvaluator extends SAIOps { // Adapter.g = Adapter.mkGraphBuilder trait Stack - type Cont[A] = Rep[Stack => A] - type Trail[A] = List[Cont[A]] + type Cont[A] = Stack => A + type Trail[A] = List[Rep[Cont[A]]] trait Frame - // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program + // a cache storing the compiled code for each function, to reduce re-compilation + val compileCache = new HashMap[Int, Rep[(Stack, Frame, Cont[Unit]) => Unit]] + + // NOTE: We don't support Ans type polymorphism yet def eval(insts: List[Instr], stack: Rep[Stack], frame: Rep[Frame], - kont: Cont[Unit], + kont: Rep[Cont[Unit]], trail: Trail[Unit]): Rep[Unit] = { if (insts.isEmpty) return kont(stack) val (inst, rest) = (insts.head, insts.tail) @@ -40,13 +43,68 @@ trait StagedWasmEvaluator extends SAIOps { case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) - case _ => + case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) + case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) + case _ => val noOp = "todo-op".reflectCtrlWith() eval(rest, noOp :: stack, frame, kont, trail) } } - def evalTop(kont: Cont[Unit], main: Option[String]): Rep[Unit] = { + def evalCall(rest: List[Instr], + stack: Rep[Stack], + frame: Rep[Frame], + kont: Rep[Cont[Unit]], + trail: Trail[Unit], + funcIndex: Int, + isTail: Boolean): Rep[Unit] = { + module.funcs(funcIndex) match { + case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => + val args = stack.take(ty.inps.size).reverse + val newStack = stack.drop(ty.inps.size) + val newFrame = frameOf(ty.inps.size + locals.size).put(args) + val callee = + if (compileCache.contains(funcIndex)) { + compileCache(funcIndex) + } else { + val callee = fun( + (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { + eval(body, stack, frame, kont, kont::Nil):Rep[Unit] + } + ) + compileCache(funcIndex) = callee + callee + } + if (isTail) + // when tail call, share the continuation for returning with the callee + callee(emptyStack, newFrame, kont) + else { + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) + ) + // We make a new trail by `restK`, since function creates a new block to escape + // (more or less like `return`) + callee(emptyStack, newFrame, kont) + } + // TODO: Support imported functions + // case Import("console", "log", _) => + // //println(s"[DEBUG] current stack: $stack") + // val I32V(v) :: newStack = stack + // println(v) + // eval(rest, newStack, frame, kont, trail) + // case Import("spectest", "print_i32", _) => + // //println(s"[DEBUG] current stack: $stack") + // val I32V(v) :: newStack = stack + // println(v) + // eval(rest, newStack, frame, kont, trail) + case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") + case _ => throw new Exception(s"Definition at $funcIndex is not callable") + } + } + + + def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { case Some(func_name) => module.defs.flatMap({ @@ -97,6 +155,22 @@ trait StagedWasmEvaluator extends SAIOps { def ::[A](v: Rep[A]): Rep[Stack] = { "stack-cons".reflectCtrlWith(v, stack) } + + def ++(v: Rep[Stack]): Rep[Stack] = { + "stack-append".reflectCtrlWith(stack, v) + } + + def take(n: Int): Rep[Stack] = { + "stack-take".reflectWith(stack, n) + } + + def drop(n: Int): Rep[Stack] = { + "stack-drop".reflectWith(stack, n) + } + + def reverse: Rep[Stack] = { + "stack-reverse".reflectWith(stack) + } } // frame creation and operations @@ -107,8 +181,13 @@ trait StagedWasmEvaluator extends SAIOps { implicit class FrameOps(frame: Rep[Frame]) { def locals(i: Int): Rep[Num] = { - "frame-locals".reflectCtrlWith(frame, i) + "frame-get".reflectCtrlWith(frame, i) + } + + def put(args: Rep[Stack]): Rep[Frame] = { + "frame-put".reflectCtrlWith(frame, args) } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { @@ -134,6 +213,7 @@ trait WasmCompilerDriver[A, B] import IR._ override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Stack")) "Stack" + else if(m.toString.endsWith("Frame")) "Frame" else super.remap(m) } } From 442d8d1d7015645f96470a2a9385281fbbbf85fc Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 21:20:11 +0800 Subject: [PATCH 05/40] factor out getFuncType --- src/main/scala/wasm/AST.scala | 11 ++++++++++- src/main/scala/wasm/ConcolicMiniWasm.scala | 21 ++++++--------------- src/main/scala/wasm/MiniWasm.scala | 15 +++------------ 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/main/scala/wasm/AST.scala b/src/main/scala/wasm/AST.scala index c59eefc9..274b2b50 100644 --- a/src/main/scala/wasm/AST.scala +++ b/src/main/scala/wasm/AST.scala @@ -270,7 +270,16 @@ case class RefType(kind: HeapType) extends ValueType case class GlobalType(ty: ValueType, mut: Boolean) extends WasmType -abstract class BlockType extends WIR +abstract class BlockType extends WIR { + def funcType: FuncType = + this match { + case VarBlockType(_, None) => + ??? // TODO: fill this branch until we handle type index correctly + case VarBlockType(_, Some(tipe)) => tipe + case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) + case ValBlockType(None) => FuncType(List(), List(), List()) + } +} case class VarBlockType(index: Int, tipe: Option[FuncType]) extends BlockType case class ValBlockType(tipe: Option[ValueType]) extends BlockType; diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index 849fd831..fef469ec 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -229,15 +229,6 @@ object Primitives { case NumType(F32Type) => F32V(rng.nextFloat()) case NumType(F64Type) => F64V(rng.nextDouble()) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value], symLocals: ArrayBuffer[SymVal]) @@ -383,7 +374,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, concStack, symStack, frame, kont, trail) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) @@ -391,7 +382,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, retStack.take(outputSize) ++ restStack, retSymStack.take(outputSize) ++ restSymStack, frame, kont, trail)(tree) eval(inner, inputs, symInputs, frame, restK, restK :: trail) case Loop(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) @@ -404,9 +395,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } @@ -422,9 +413,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 11eb301b..2a5abe6d 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -229,15 +229,6 @@ object Primtives { case VecType(kind) => ??? case RefType(kind) => RefNullV(kind) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(locals: ArrayBuffer[Value]) @@ -380,7 +371,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, stack, frame, kont, trail) case Unreachable => throw Trap() case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -389,7 +380,7 @@ case class Evaluator(module: ModuleInstance) { // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -397,7 +388,7 @@ case class Evaluator(module: ModuleInstance) { eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail) loop(inputs) case If(ty, thn, els) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) From ab679f36c4181fba8bf25866f992bc59ae35f71b Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 22:15:15 +0800 Subject: [PATCH 06/40] fix: use restK when non-tail call --- src/main/scala/wasm/StagedMiniWasm.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f968b76c..bb3a7291 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -85,7 +85,7 @@ trait StagedWasmEvaluator extends SAIOps { ) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(emptyStack, newFrame, kont) + callee(emptyStack, newFrame, restK) } // TODO: Support imported functions // case Import("console", "log", _) => From 8ba8657782bf1a7cea7ce1ae7e8ff81847089692 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 22:43:55 +0800 Subject: [PATCH 07/40] compile Block-like instructions(if-else, loop, block) --- benchmarks/wasm/staged/push-drop.wat | 11 +++++- src/main/scala/wasm/StagedMiniWasm.scala | 46 ++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 2c630da8..a32fb24a 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -7,7 +7,16 @@ local.get 1 drop drop - (call 1)) + (call 1) + i32.const 3 + if (result i32) ;; label = @1 + i32.const 1 + else + local.get 1 + end + (loop + i32.const 4) + ) (func (;1;) (type 1) (param i32 i32) (result i32) (local i32 i32) local.get 0 diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index bb3a7291..025a55df 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -10,7 +10,7 @@ import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ -import gensym.wasm.ast.{Const => ConstInstr} +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} @virtualize @@ -40,9 +40,41 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case WasmBlock(ty, inner) => + val funcTy = ty.funcType + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + eval(inner, inputs, frame, restK, restK :: trail) + case Loop(ty, inner) => + val funcTy = ty.funcType + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + def loop(retStack: Rep[Stack]): Rep[Unit] = + eval(inner, retStack.take(funcTy.inps.size), frame, restK, fun(loop _) :: trail) + loop(inputs) + case If(ty, thn, els) => + val funcTy = ty.funcType + val (cond, newStack) = (stack.head, stack.tail) + val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) + // TODO: can we avoid code duplication here? + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + if (cond != 0) { + eval(thn, inputs, frame, restK, restK :: trail) + } else { + eval(els, inputs, frame, restK, restK :: trail) + } case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => @@ -148,7 +180,11 @@ trait StagedWasmEvaluator extends SAIOps { // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { - def tail(): Rep[Stack] = { + def head: Rep[Num] = { + "stack-head".reflectCtrlWith(stack) + } + + def tail: Rep[Stack] = { "stack-tail".reflectCtrlWith(stack) } @@ -171,6 +207,10 @@ trait StagedWasmEvaluator extends SAIOps { def reverse: Rep[Stack] = { "stack-reverse".reflectWith(stack) } + + def splitAt(n: Int): (Rep[Stack], Rep[Stack]) = { + (take(n), drop(n)) + } } // frame creation and operations From e2cc801df0b5998242de2be8d2f91d5fa737fd19 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 01:25:56 +0800 Subject: [PATCH 08/40] branching instructions --- benchmarks/wasm/staged/push-drop.wat | 13 ++++++++++++- src/main/scala/wasm/StagedMiniWasm.scala | 23 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index a32fb24a..196fbc19 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -14,8 +14,19 @@ else local.get 1 end + (block + (block + i32.const 4 + i32.const 2 + ;; br_table 0 0 ;; the compilation of br_table is problematic now + ) + ) + (loop - i32.const 4) + i32.const 5 + br 0) + return + i32.const 6 ) (func (;1;) (type 1) (param i32 i32) (result i32) (local i32 i32) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 025a55df..02c7dbac 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -75,6 +75,22 @@ trait StagedWasmEvaluator extends SAIOps { } else { eval(els, inputs, frame, restK, restK :: trail) } + case Br(label) => + trail(label)(stack) + case BrIf(label) => + val (cond, newStack) = (stack.head, stack.tail) + if (cond != 0) trail(label)(newStack) + else eval(rest, newStack, frame, kont, trail) + case BrTable(labels, default) => + val (cond, newStack) = (stack.head, stack.tail) + if (cond.toInt < labels.length) { + var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) + val goto: Rep[Cont[Unit]] = targets(cond.toInt) + goto(newStack) // TODO: this line will trigger an exception + } else { + trail(default)(newStack) + } + case Return => trail.last(stack) case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => @@ -229,6 +245,13 @@ trait StagedWasmEvaluator extends SAIOps { } } + + // runtime Num type + implicit class NumOps(num: Rep[Num]) { + def toInt: Rep[Int] = { + "num-to-int".reflectWith(num) + } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { From fa3d6280330bb357277323d0d774b1b8b7acf94f Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 14:13:48 +0800 Subject: [PATCH 09/40] local set --- benchmarks/wasm/staged/push-drop.wat | 2 ++ src/main/scala/wasm/StagedMiniWasm.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 196fbc19..5b6c500b 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -5,6 +5,8 @@ i32.const 2 local.get 0 local.get 1 + local.set 0 + local.tee 1 drop drop (call 1) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 02c7dbac..9f5ef991 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -43,6 +43,14 @@ trait StagedWasmEvaluator extends SAIOps { case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case LocalSet(i) => + val (v, newStack) = (stack.head, stack.tail) + frame(i) = v + eval(rest, newStack, frame, kont, trail) + case LocalTee(i) => + val (v, _) = (stack.head, stack.tail) + frame(i) = v + eval(rest, stack, frame, kont, trail) case WasmBlock(ty, inner) => val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) @@ -244,6 +252,9 @@ trait StagedWasmEvaluator extends SAIOps { "frame-put".reflectCtrlWith(frame, args) } + def update(i: Int, value: Rep[Num]) = { + "frame-update".reflectCtrlWith(frame, i, value) + } } // runtime Num type From d78a96b5afe5f619c02eb75e15ca875fd58f2835 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 15:18:44 +0800 Subject: [PATCH 10/40] operators --- benchmarks/wasm/staged/push-drop.wat | 2 + src/main/scala/wasm/StagedMiniWasm.scala | 110 ++++++++++++++++++++++- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 5b6c500b..903b771d 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -9,6 +9,8 @@ local.tee 1 drop drop + i32.add + nop (call 1) i32.const 3 if (result i32) ;; label = @1 diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 9f5ef991..5112cecb 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -51,6 +51,21 @@ trait StagedWasmEvaluator extends SAIOps { val (v, _) = (stack.head, stack.tail) frame(i) = v eval(rest, stack, frame, kont, trail) + case Nop => + eval(rest, stack, frame, kont, trail) + case Unreachable => unreachable() + case Test(op) => + val (v, newStack) = (stack.head, stack.tail) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail) + case Unary(op) => + val (v, newStack) = (stack.head, stack.tail) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail) + case Binary(op) => + val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail) + case Compare(op) => + val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail) case WasmBlock(ty, inner) => val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) @@ -159,6 +174,41 @@ trait StagedWasmEvaluator extends SAIOps { } } + def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { + case Eqz(_) => if (value.toInt == 0) I32(1) else I32(0) + } + + def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalBinOp(op: BinOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case _ => ??? + } + + def evalRelOp(op: RelOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { @@ -202,6 +252,19 @@ trait StagedWasmEvaluator extends SAIOps { "empty-stack".reflectWith() } + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith() + } + + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) + } + + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) + } + // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { def head: Rep[Num] = { @@ -255,13 +318,54 @@ trait StagedWasmEvaluator extends SAIOps { def update(i: Int, value: Rep[Num]) = { "frame-update".reflectCtrlWith(frame, i, value) } + } // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = { - "num-to-int".reflectWith(num) - } + + def toInt: Rep[Int] = "num-to-int".reflectWith(num) + + def clz(): Rep[Num] = "unary-clz".reflectWith(num) + + def ctz(): Rep[Num] = "unary-ctz".reflectWith(num) + + def popcnt(): Rep[Num] = "unary-popcnt".reflectWith(num) + + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith(num, rhs) + + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith(num, rhs) + + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith(num, rhs) + + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith(num, rhs) + + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith(num, rhs) + + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith(num, rhs) + + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith(num, rhs) + + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith(num, rhs) + + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith(num, rhs) + + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith(num, rhs) + + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith(num, rhs) + + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith(num, rhs) + + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith(num, rhs) + + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith(num, rhs) + + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith(num, rhs) + + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) + + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { From dac7e1cef393609c1cf59ee75b3ec7837c35b085 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 15:46:35 +0800 Subject: [PATCH 11/40] global instructions --- benchmarks/wasm/staged/push-drop.wat | 4 ++++ src/main/scala/wasm/StagedMiniWasm.scala | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 903b771d..db7b18bd 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,4 +1,5 @@ (module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) (func (;0;) (type 1) (result i32) (local i32 i32) i32.const 2 @@ -12,7 +13,10 @@ i32.add nop (call 1) + global.get 1 i32.const 3 + global.set 2 ;; TODO: this line was compiled to global.get, fix the parser! + if (result i32) ;; label = @1 i32.const 1 else diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 5112cecb..b69a7fbc 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -51,6 +51,15 @@ trait StagedWasmEvaluator extends SAIOps { val (v, _) = (stack.head, stack.tail) frame(i) = v eval(rest, stack, frame, kont, trail) + case GlobalGet(i) => + eval(rest, Global.globalGet(i) :: stack, frame, kont, trail) + case GlobalSet(i) => + val (value, newStack) = (stack.head, stack.tail) + module.globals(i).ty match { + case GlobalType(tipe, true) => Global.globalSet(i, value) + case _ => throw new Exception("Cannot set immutable global") + } + eval(rest, newStack, frame, kont, trail) case Nop => eval(rest, stack, frame, kont, trail) case Unreachable => unreachable() @@ -265,6 +274,17 @@ trait StagedWasmEvaluator extends SAIOps { "I64V".reflectWith(i) } + // global read/write + object Global{ + def globalGet(i: Int): Rep[Num] = { + "global-get".reflectWith(i) + } + + def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { + "global-set".reflectCtrlWith(i, value) + } + } + // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { def head: Rep[Num] = { From cc65f6358d80b7c38acdad3c43337344354e80d5 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 16:06:21 +0800 Subject: [PATCH 12/40] placeholder for mem instructions --- src/main/scala/wasm/StagedMiniWasm.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index b69a7fbc..0373f49f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -60,6 +60,9 @@ trait StagedWasmEvaluator extends SAIOps { case _ => throw new Exception("Cannot set immutable global") } eval(rest, newStack, frame, kont, trail) + case MemorySize => ??? + case MemoryGrow => ??? + case MemoryFill => ??? case Nop => eval(rest, stack, frame, kont, trail) case Unreachable => unreachable() From cf3063a4bda619bac6106e05577c1fb5c544d297 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 16:26:11 +0800 Subject: [PATCH 13/40] scala code generation --- src/main/scala/wasm/StagedMiniWasm.scala | 61 +++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 0373f49f..80bec0d9 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -42,7 +42,7 @@ trait StagedWasmEvaluator extends SAIOps { case Drop => eval(rest, stack.tail, frame, kont, trail) case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, kont, trail) + eval(rest, frame.get(i) :: stack, frame, kont, trail) case LocalSet(i) => val (v, newStack) = (stack.head, stack.tail) frame(i) = v @@ -330,7 +330,7 @@ trait StagedWasmEvaluator extends SAIOps { implicit class FrameOps(frame: Rep[Frame]) { - def locals(i: Int): Rep[Num] = { + def get(i: Int): Rep[Num] = { "frame-get".reflectCtrlWith(frame, i) } @@ -393,17 +393,74 @@ trait StagedWasmEvaluator extends SAIOps { } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { + case Node(_, "frame-update", List(frame, i, value), _) => + // TODO: what is the protocol of automatic new line insertion? + shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + shallow(i); emit(".globalSet("); shallow(value); emit(")") case _ => super.traverse(n) } // code generation for pure nodes override def shallow(n: Node): Unit = n match { + case Node(_, "stack-take", List(stack, n), _) => + shallow(stack); emit(".take("); shallow(n); emit(")") + case Node(_, "stack-drop", List(stack, n), _) => + shallow(stack); emit(".drop("); shallow(n); emit(")") + case Node(_, "stack-append", List(stack1, stack2), _) => + shallow(stack1); emit(".++("); shallow(stack2); emit(")") + case Node(_, "stack-head", List(stack), _) => + shallow(stack); emit(".head") + case Node(_, "stack-reverse", List(stack), _) => + shallow(stack); emit(".reverse") case Node(_, "stack-cons", List(v, stack), _) => shallow(stack); emit(".push("); shallow(v); emit(")") case Node(_, "stack-tail", List(stack), _) => shallow(stack); emit(".pop()") case Node(_, "empty-stack", _, _) => emit("new Stack()") + case Node(_, "frame-of", List(size), _) => + emit("new Frame("); shallow(size); emit(")") + case Node(_, "frame-get", List(frame, i), _) => + shallow(frame); emit("("); shallow(i); emit(")") + case Node(_, "frame-put", List(frame, args), _) => + shallow(frame); emit(".put("); shallow(args); emit(")") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) case _ => super.shallow(n) } } From 80bfa682c1d06d2cbcb42517e6b89239171129de Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:03:56 +0800 Subject: [PATCH 14/40] some imported function --- src/main/scala/wasm/StagedMiniWasm.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 80bec0d9..4e6f7107 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -170,17 +170,12 @@ trait StagedWasmEvaluator extends SAIOps { // (more or less like `return`) callee(emptyStack, newFrame, restK) } - // TODO: Support imported functions - // case Import("console", "log", _) => - // //println(s"[DEBUG] current stack: $stack") - // val I32V(v) :: newStack = stack - // println(v) - // eval(rest, newStack, frame, kont, trail) - // case Import("spectest", "print_i32", _) => - // //println(s"[DEBUG] current stack: $stack") - // val I32V(v) :: newStack = stack - // println(v) - // eval(rest, newStack, frame, kont, trail) + case Import("console", "log", _) + | Import("spectest", "print_i32", _) => + //println(s"[DEBUG] current stack: $stack") + val (v, newStack) = (stack.head, stack.tail) + println(v) + eval(rest, newStack, frame, kont, trail) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } From 881eda45df5c730a80938b2d7e0e1c82656febaa Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:04:26 +0800 Subject: [PATCH 15/40] polish --- .../staged/{push-drop.wat => scratch.wat} | 2 ++ src/main/scala/wasm/ConcolicMiniWasm.scala | 8 ++--- src/main/scala/wasm/StagedMiniWasm.scala | 31 +++++++++++-------- src/test/scala/genwasym/TestStagedEval.scala | 2 +- 4 files changed, 25 insertions(+), 18 deletions(-) rename benchmarks/wasm/staged/{push-drop.wat => scratch.wat} (87%) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/scratch.wat similarity index 87% rename from benchmarks/wasm/staged/push-drop.wat rename to benchmarks/wasm/staged/scratch.wat index db7b18bd..2b0b3ede 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -1,3 +1,5 @@ +;; this file contains some wasm instructions to test if the compiler works, +;; and its execution is meaningless. (module $push-drop (global (;0;) (mut i32) (i32.const 1048576)) (func (;0;) (type 1) (result i32) diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index fef469ec..fec869fe 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -395,9 +395,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } @@ -413,9 +413,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 4e6f7107..b2c8a3eb 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -160,7 +160,7 @@ trait StagedWasmEvaluator extends SAIOps { } if (isTail) // when tail call, share the continuation for returning with the callee - callee(emptyStack, newFrame, kont) + callee(Stack.emptyStack, newFrame, kont) else { val restK = fun( (retStack: Rep[Stack]) => @@ -168,7 +168,7 @@ trait StagedWasmEvaluator extends SAIOps { ) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(emptyStack, newFrame, restK) + callee(Stack.emptyStack, newFrame, restK) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => @@ -182,7 +182,7 @@ trait StagedWasmEvaluator extends SAIOps { } def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { - case Eqz(_) => if (value.toInt == 0) I32(1) else I32(0) + case Eqz(_) => if (value.toInt == 0) Values.I32(1) else Values.I32(0) } def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { @@ -243,7 +243,7 @@ trait StagedWasmEvaluator extends SAIOps { } val (instrs, localSize) = (funBody.body, funBody.locals.size) val frame = frameOf(localSize) - eval(instrs, emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { @@ -253,10 +253,11 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(fun(haltK), main) } - // stack creation and operations - def emptyStack: Rep[Stack] = { - "empty-stack".reflectWith() + object Stack { + def emptyStack: Rep[Stack] = { + "empty-stack".reflectWith() + } } // call unreachable @@ -264,12 +265,15 @@ trait StagedWasmEvaluator extends SAIOps { "unreachable".reflectCtrlWith() } - def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) - } + // runtime values + object Values { + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) + } - def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) + } } // global read/write @@ -383,9 +387,9 @@ trait StagedWasmEvaluator extends SAIOps { def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) - } } + trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { case Node(_, "frame-update", List(frame, i, value), _) => @@ -459,6 +463,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case _ => super.shallow(n) } } + trait WasmCompilerDriver[A, B] extends SAIDriver[A, B] with StagedWasmEvaluator { q => override val codegen = new StagedWasmScalaGen { diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index cc7197f0..d8a12839 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -15,6 +15,6 @@ class TestStagedEval extends FunSuite { } test("push-drop") { - testFile("./benchmarks/wasm/staged/push-drop.wat") + testFile("./benchmarks/wasm/staged/scratch.wat") } } From 7a2bfd4ea55958ba6c397adac9c73b0496e2393e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:04:35 +0800 Subject: [PATCH 16/40] ci --- .github/workflows/scala.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 5610a96e..4677da77 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -78,3 +78,4 @@ jobs: sbt 'testOnly gensym.wasm.TestScriptRun' sbt 'testOnly gensym.wasm.TestConcolic' sbt 'testOnly gensym.wasm.TestDriver' + sbt 'testOnly gensym.wasm.TestStagedEval' From 294fcea1df3586ca71ab53e17a7dd277b9d55243 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:13:53 +0800 Subject: [PATCH 17/40] tweak --- src/main/scala/wasm/StagedMiniWasm.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index b2c8a3eb..9eb5efaa 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -145,7 +145,8 @@ trait StagedWasmEvaluator extends SAIOps { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => val args = stack.take(ty.inps.size).reverse val newStack = stack.drop(ty.inps.size) - val newFrame = frameOf(ty.inps.size + locals.size).put(args) + val newFrame = frameOf(ty.inps.size + locals.size) + newFrame.putAll(args) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) @@ -333,8 +334,8 @@ trait StagedWasmEvaluator extends SAIOps { "frame-get".reflectCtrlWith(frame, i) } - def put(args: Rep[Stack]): Rep[Frame] = { - "frame-put".reflectCtrlWith(frame, args) + def putAll(args: Rep[Stack]) = { + "frame-putAll".reflectCtrlWith(frame, args) } def update(i: Int, value: Rep[Num]) = { @@ -422,8 +423,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("new Frame("); shallow(size); emit(")") case Node(_, "frame-get", List(frame, i), _) => shallow(frame); emit("("); shallow(i); emit(")") - case Node(_, "frame-put", List(frame, args), _) => - shallow(frame); emit(".put("); shallow(args); emit(")") + case Node(_, "frame-putAll", List(frame, args), _) => + shallow(frame); emit(".putAll("); shallow(args); emit(")") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") case Node(_, "binary-add", List(lhs, rhs), _) => From f450b5c68dbca114a5e20fd3fcb3ee60fc66c705 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 00:18:24 +0800 Subject: [PATCH 18/40] try some simplification --- src/main/scala/wasm/StagedMiniWasm.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 9eb5efaa..2a50b6a5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -307,11 +307,13 @@ trait StagedWasmEvaluator extends SAIOps { } def take(n: Int): Rep[Stack] = { - "stack-take".reflectWith(stack, n) + if (n == 0) Stack.emptyStack + else "stack-take".reflectWith(stack, n) } def drop(n: Int): Rep[Stack] = { - "stack-drop".reflectWith(stack, n) + if (n == 0) stack + else "stack-drop".reflectWith(stack, n) } def reverse: Rep[Stack] = { From 336eec590537a80d68ec751a4aee9f6bf7920fb7 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 01:32:54 +0800 Subject: [PATCH 19/40] improve runtime(the prelude) --- src/main/scala/wasm/StagedMiniWasm.scala | 78 ++++++++++++++++++++---- 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2a50b6a5..8fa27b2f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -40,7 +40,7 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) + case WasmConst(num) => eval(rest, Values.lift(num) :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.get(i) :: stack, frame, kont, trail) case LocalSet(i) => @@ -105,7 +105,7 @@ trait StagedWasmEvaluator extends SAIOps { (retStack: Rep[Stack]) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) ) - if (cond != 0) { + if (cond != Values.I32(0)) { eval(thn, inputs, frame, restK, restK :: trail) } else { eval(els, inputs, frame, restK, restK :: trail) @@ -268,6 +268,13 @@ trait StagedWasmEvaluator extends SAIOps { // runtime values object Values { + def lift(num: Num): Rep[Num] = { + num match { + case I32V(i) => I32(i) + case I64V(i) => I64(i) + } + } + def I32(i: Rep[Int]): Rep[Num] = { "I32V".reflectWith(i) } @@ -298,7 +305,7 @@ trait StagedWasmEvaluator extends SAIOps { "stack-tail".reflectCtrlWith(stack) } - def ::[A](v: Rep[A]): Rep[Stack] = { + def ::[A](v: Rep[Num]): Rep[Stack] = { "stack-cons".reflectCtrlWith(v, stack) } @@ -416,11 +423,11 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "stack-reverse", List(stack), _) => shallow(stack); emit(".reverse") case Node(_, "stack-cons", List(v, stack), _) => - shallow(stack); emit(".push("); shallow(v); emit(")") + shallow(stack); emit(".::("); shallow(v); emit(")") case Node(_, "stack-tail", List(stack), _) => - shallow(stack); emit(".pop()") + shallow(stack); emit(".tail") case Node(_, "empty-stack", _, _) => - emit("new Stack()") + emit("Nil") case Node(_, "frame-of", List(size), _) => emit("new Frame("); shallow(size); emit(")") case Node(_, "frame-get", List(frame, i), _) => @@ -480,11 +487,60 @@ trait WasmCompilerDriver[A, B] } override val prelude = - """ - object Prelude { - } - import Prelude._ - """ + """ +object Prelude { + sealed abstract class Num { + def +(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x + y) + case (I64V(x), I64V(y)) => I64V(x + y) + case _ => throw new RuntimeException("Invalid addition") + } + + def -(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x - y) + case (I64V(x), I64V(y)) => I64V(x - y) + case _ => throw new RuntimeException("Invalid subtraction") + } + + def !=(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(if (x != y) 1 else 0) + case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) + case _ => throw new RuntimeException("Invalid inequality") + } + } + case class I32V(i: Int) extends Num + case class I64V(i: Long) extends Num + + + type Stack = List[Num] + + class Frame(val size: Int) { + private val data = new Array[Num](size) + def apply(i: Int): Num = data(i) + def update(i: Int, v: Num): Unit = data(i) = v + def putAll(xs: List[Num]): Unit = { + for (i <- 0 until xs.size) { + data(i) = xs(i) + } + } + } + + object Global { + // TODO: create global with specific size + private val globals = new Array[Num](10) + def globalGet(i: Int): Num = globals(i) + def globalSet(i: Int, v: Num): Unit = globals(i) = v + } +} +import Prelude._ + +object Main { + def main(args: Array[String]): Unit = { + val snippet = new Snippet() + snippet(()) + } +} +""" } object PartialEvaluator { From 6a666f35f584ab9abcdbd391f1d3de7a60ee5339 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 22:16:16 +0800 Subject: [PATCH 20/40] some fixes --- src/main/scala/wasm/StagedMiniWasm.scala | 17 +++++++++++++---- src/test/scala/genwasym/TestStagedEval.scala | 4 ++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 8fa27b2f..208f8b95 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -114,7 +114,7 @@ trait StagedWasmEvaluator extends SAIOps { trail(label)(stack) case BrIf(label) => val (cond, newStack) = (stack.head, stack.tail) - if (cond != 0) trail(label)(newStack) + if (cond != Values.I32(0)) trail(label)(newStack) else eval(rest, newStack, frame, kont, trail) case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) @@ -129,8 +129,8 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => - val noOp = "todo-op".reflectCtrlWith() - eval(rest, noOp :: stack, frame, kont, trail) + val todo = "todo-op".reflectCtrlWith() + eval(rest, todo :: stack, frame, kont, trail) } } @@ -249,7 +249,7 @@ trait StagedWasmEvaluator extends SAIOps { def evalTop(main: Option[String]): Rep[Unit] = { val haltK: Rep[Stack] => Rep[Unit] = stack => { - "no-op".reflectCtrlWith() + "no-op".reflectCtrlWith[Unit]() } evalTop(fun(haltK), main) } @@ -470,6 +470,10 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "relation-geu", List(lhs, rhs), _) => shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") case _ => super.shallow(n) } } @@ -507,6 +511,11 @@ object Prelude { case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) case _ => throw new RuntimeException("Invalid inequality") } + + def toInt: Int = this match { + case I32V(i) => i + case I64V(i) => i.toInt + } } case class I32V(i: Int) extends Num case class I64V(i: Long) extends Num diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d8a12839..fa7e8c65 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -10,11 +10,11 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFile(filename: String, main: Option[String] = None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, None) + val code = PartialEvaluator(moduleInst, main) println(code) } - test("push-drop") { + test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } } From 9947becca7cfea42e597342fdfbe25ae90ef192d Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 00:52:36 +0800 Subject: [PATCH 21/40] fix: Frame creation is not optimizable --- src/main/scala/wasm/StagedMiniWasm.scala | 53 ++++++++++++++++---- src/test/scala/genwasym/TestStagedEval.scala | 6 ++- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 208f8b95..e897cd57 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -48,7 +48,7 @@ trait StagedWasmEvaluator extends SAIOps { frame(i) = v eval(rest, newStack, frame, kont, trail) case LocalTee(i) => - val (v, _) = (stack.head, stack.tail) + val v = stack.head frame(i) = v eval(rest, stack, frame, kont, trail) case GlobalGet(i) => @@ -111,11 +111,19 @@ trait StagedWasmEvaluator extends SAIOps { eval(els, inputs, frame, restK, restK :: trail) } case Br(label) => + info(s"Jump to $label") trail(label)(stack) case BrIf(label) => val (cond, newStack) = (stack.head, stack.tail) - if (cond != Values.I32(0)) trail(label)(newStack) - else eval(rest, newStack, frame, kont, trail) + if (cond != Values.I32(0)) { + info("The br_if's condition is ", cond) + info(s"Jump to $label") + trail(label)(newStack) + } else { + info("The br_if's condition is ",cond) + info(s"Continue") + eval(rest, newStack, frame, kont, trail) + } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) if (cond.toInt < labels.length) { @@ -147,12 +155,14 @@ trait StagedWasmEvaluator extends SAIOps { val newStack = stack.drop(ty.inps.size) val newFrame = frameOf(ty.inps.size + locals.size) newFrame.putAll(args) + info("New frame:", newFrame) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { val callee = fun( (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stack =", stack, ", frame =", frame) eval(body, stack, frame, kont, kont::Nil):Rep[Unit] } ) @@ -223,7 +233,7 @@ trait StagedWasmEvaluator extends SAIOps { case Some(func_name) => module.defs.flatMap({ case Export(`func_name`, ExportFunc(fid)) => - println(s"Entering function $main") + Predef.println(s"Now compiling start with function $main") module.funcs(fid) match { case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) case _ => throw new Exception("Entry function has no concrete body") @@ -247,8 +257,12 @@ trait StagedWasmEvaluator extends SAIOps { eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } - def evalTop(main: Option[String]): Rep[Unit] = { + def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { val haltK: Rep[Stack] => Rep[Unit] = stack => { + if (printRes) { + print("Final stack: ") + println(stack) + } "no-op".reflectCtrlWith[Unit]() } evalTop(fun(haltK), main) @@ -266,6 +280,10 @@ trait StagedWasmEvaluator extends SAIOps { "unreachable".reflectCtrlWith() } + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith(xs: _*) + } + // runtime values object Values { def lift(num: Num): Rep[Num] = { @@ -334,7 +352,7 @@ trait StagedWasmEvaluator extends SAIOps { // frame creation and operations def frameOf(size: Int): Rep[Frame] = { - "frame-of".reflectWith(size) + "frame-of".reflectCtrlWith(size) } implicit class FrameOps(frame: Rep[Frame]) { @@ -408,6 +426,11 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "global-set", List(i, value), _) => shallow(i); emit(".globalSet("); shallow(value); emit(")") case _ => super.traverse(n) + case Node(_, "info", xs, _) => + emit("println("); xs.foreach { x => + shallow(x); emit(", ") + }; emit(")") + } // code generation for pure nodes @@ -525,13 +548,19 @@ object Prelude { class Frame(val size: Int) { private val data = new Array[Num](size) - def apply(i: Int): Num = data(i) + def apply(i: Int): Num = { + info(s"frame(${i}) = ${data(i)}") + data(i) + } def update(i: Int, v: Num): Unit = data(i) = v def putAll(xs: List[Num]): Unit = { for (i <- 0 until xs.size) { data(i) = xs(i) } } + override def toString: String = { + "Frame(" + data.mkString(", ") + ")" + } } object Global { @@ -540,6 +569,12 @@ object Prelude { def globalGet(i: Int): Num = globals(i) def globalSet(i: Int, v: Num): Unit = globals(i) = v } + + def info(xs: Any*): Unit = { + if (System.getenv("DEBUG") != null) { + println("[INFO] " + xs.mkString(" ")) + } + } } import Prelude._ @@ -553,12 +588,12 @@ object Main { } object PartialEvaluator { - def apply(moduleInst: ModuleInstance, main: Option[String]): String = { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main) + evalTop(main, printRes) } } code.code diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index fa7e8c65..2d9b3693 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -8,13 +8,15 @@ import gensym.wasm.parser._ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { - def testFile(filename: String, main: Option[String] = None) = { + def testFile(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, main) + val code = PartialEvaluator(moduleInst, main, true) println(code) } test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } + + test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } } From e7da82323c419895b12bcef6e09cd9e23dd0ff13 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 11:02:14 +0800 Subject: [PATCH 22/40] demo br_table's attempts --- benchmarks/wasm/staged/scratch.wat | 2 +- src/main/scala/wasm/StagedMiniWasm.scala | 29 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat index 2b0b3ede..6b0a4c44 100644 --- a/benchmarks/wasm/staged/scratch.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -28,7 +28,7 @@ (block i32.const 4 i32.const 2 - ;; br_table 0 0 ;; the compilation of br_table is problematic now + br_table 0 1 0 ;; the compilation of br_table is problematic now ) ) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index e897cd57..53eca440 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -126,10 +126,31 @@ trait StagedWasmEvaluator extends SAIOps { } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) - if (cond.toInt < labels.length) { - var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) - val goto: Rep[Cont[Unit]] = targets(cond.toInt) - goto(newStack) // TODO: this line will trigger an exception + if (cond.toInt < unit(labels.length)) { + // Implementation 1(trigger runtime exception): + // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) + // val goto: Rep[Cont[Unit]] = targets(cond.toInt) + // goto(newStack) // TODO: confirm why this line will trigger an exception + + // Implementation 2(if-expression is not generated at all): + // var goto: Rep[Cont[Unit]] = null + // for (i <- Range(0, labels.length)) { + // if (i != cond.toInt) { + // info(s"Jump(br_table) to ${labels(i)}") + // return trail(labels(i))(newStack) + // } + // } + + // Implementation 3(assignment to `goto` is not generated): + var goto: Rep[Cont[Unit]] = null + for (i <- Range(0, labels.length)) { + if (i != cond.toInt) { + info(s"Jump(br_table) to ${labels(i)}") + goto = trail(labels(i)) + } + } + info(s"Jump to goto target") + goto(newStack) } else { trail(default)(newStack) } From 2de28f5a336e87c2b04a97e56698e99861ff7447 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 11:30:00 +0800 Subject: [PATCH 23/40] fix: tail call --- src/main/scala/wasm/MiniWasm.scala | 4 ++-- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 2a5abe6d..84a8bd88 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -255,8 +255,8 @@ case class Evaluator(module: ModuleInstance) { val frameLocals = args ++ locals.map(zero(_)) val newFrame = Frame(ArrayBuffer(frameLocals: _*)) if (isTail) - // when tail call, share the continuation for returning with the callee - eval(body, List(), newFrame, kont, List(kont)) + // when tail call, return to the caller's return continuation + eval(body, List(), newFrame, trail.last, List(trail.last)) else { val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 53eca440..be9526c1 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -191,8 +191,8 @@ trait StagedWasmEvaluator extends SAIOps { callee } if (isTail) - // when tail call, share the continuation for returning with the callee - callee(Stack.emptyStack, newFrame, kont) + // when tail call, return to the caller's return continuation + callee(Stack.emptyStack, newFrame, trail.last) else { val restK = fun( (retStack: Rep[Stack]) => From b5a69dca1314b14af65b0d8c5cea60747a37746b Mon Sep 17 00:00:00 2001 From: ahuoguo Date: Tue, 29 Apr 2025 15:13:20 +0200 Subject: [PATCH 24/40] fix global --- benchmarks/wasm/global.wat | 19 +++++++++++++++++++ benchmarks/wasm/staged/scratch.wat | 4 ++-- src/main/scala/wasm/Parser.scala | 2 +- src/test/scala/genwasym/TestEval.scala | 3 +++ 4 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 benchmarks/wasm/global.wat diff --git a/benchmarks/wasm/global.wat b/benchmarks/wasm/global.wat new file mode 100644 index 00000000..236467ef --- /dev/null +++ b/benchmarks/wasm/global.wat @@ -0,0 +1,19 @@ +(module + (type (;0;) (func (result i32))) + (type (;1;) (func)) + + (func (;0;) (type 0) (result i32) + i32.const 42 + global.set 0 + global.get 0 + ) + (func (;1;) (type 1) + call 0 + ;; should be 42 + ;; drop + ) + (start 1) + (memory (;0;) 2) + (export "main" (func 1)) + (global (;0;) (mut i32) (i32.const 0)) +) \ No newline at end of file diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat index 6b0a4c44..b725d770 100644 --- a/benchmarks/wasm/staged/scratch.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -15,9 +15,9 @@ i32.add nop (call 1) - global.get 1 + global.get 0 i32.const 3 - global.set 2 ;; TODO: this line was compiled to global.get, fix the parser! + global.set 0 if (result i32) ;; label = @1 i32.const 1 diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index 40b497e0..0ce9fa94 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -314,7 +314,7 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { else if (ctx.LOCAL_GET() != null) LocalGet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_SET() != null) LocalSet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_TEE() != null) LocalTee(getVar(ctx.idx(0)).toInt) - else if (ctx.GLOBAL_SET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) + else if (ctx.GLOBAL_SET() != null) GlobalSet(getVar(ctx.idx(0)).toInt) else if (ctx.GLOBAL_GET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) else if (ctx.load() != null) { val ty = visitNumType(ctx.load.numType) diff --git a/src/test/scala/genwasym/TestEval.scala b/src/test/scala/genwasym/TestEval.scala index 38453996..2e358375 100644 --- a/src/test/scala/genwasym/TestEval.scala +++ b/src/test/scala/genwasym/TestEval.scala @@ -81,6 +81,9 @@ class TestEval extends FunSuite { test("loop block - poly br") { testFile("./benchmarks/wasm/loop_poly.wat", None, ExpStack(List(2, 1))) } + test("global") { + testFile("./benchmarks/wasm/global.wat", None, ExpInt(42)) + } // just a test for .bin.wast utility // the complete tests can be seen at https://github.com/Generative-Program-Analysis/wasm-cps/ From de8f18e2c51a6e1f4be5c9b0e1bd2bb19d3280b3 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 21:20:47 +0800 Subject: [PATCH 25/40] fix: code generation for global.set --- src/main/scala/wasm/StagedMiniWasm.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index be9526c1..3cb0ddd7 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -445,7 +445,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { // TODO: what is the protocol of automatic new line insertion? shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") case Node(_, "global-set", List(i, value), _) => - shallow(i); emit(".globalSet("); shallow(value); emit(")") + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") case _ => super.traverse(n) case Node(_, "info", xs, _) => emit("println("); xs.foreach { x => From 3bbd27e408323f66a6bd7c8752dbb7c1418bd749 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Tue, 29 Apr 2025 16:16:10 +0200 Subject: [PATCH 26/40] brtable seems to work, but there is code duplication problem --- benchmarks/wasm/staged/brtable.wat | 11 +++++++++++ src/main/scala/wasm/StagedMiniWasm.scala | 13 +++++++++++-- src/test/scala/genwasym/TestStagedEval.scala | 6 ++++++ third-party/lms-clean | 2 +- 4 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 benchmarks/wasm/staged/brtable.wat diff --git a/benchmarks/wasm/staged/brtable.wat b/benchmarks/wasm/staged/brtable.wat new file mode 100644 index 00000000..91133d70 --- /dev/null +++ b/benchmarks/wasm/staged/brtable.wat @@ -0,0 +1,11 @@ +(module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) + (func (;0;) (type 1) (result i32) + i32.const 2 + (block + (block + br_table 0 1 0 + ) + ) + ) + (start 0)) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 3cb0ddd7..bff5cca5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -101,8 +101,7 @@ trait StagedWasmEvaluator extends SAIOps { val (cond, newStack) = (stack.head, stack.tail) val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) // TODO: can we avoid code duplication here? - val restK = fun( - (retStack: Rep[Stack]) => + val restK = fun((retStack: Rep[Stack]) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) ) if (cond != Values.I32(0)) { @@ -126,6 +125,15 @@ trait StagedWasmEvaluator extends SAIOps { } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) + def aux(choices: List[Int], idx: Int): Rep[Unit] = { + if (choices.isEmpty) trail(default)(newStack) + else { + if (cond.toInt == idx) trail(choices.head)(newStack) + else aux(choices.tail, idx + 1) + } + } + aux(labels, 0) + /* if (cond.toInt < unit(labels.length)) { // Implementation 1(trigger runtime exception): // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) @@ -154,6 +162,7 @@ trait StagedWasmEvaluator extends SAIOps { } else { trail(default)(newStack) } + */ case Return => trail.last(stack) case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 2d9b3693..b572f90f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -14,9 +14,15 @@ class TestStagedEval extends FunSuite { println(code) } + /* test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + */ + + test("brtable") { + testFile("./benchmarks/wasm/staged/brtable.wat") + } } diff --git a/third-party/lms-clean b/third-party/lms-clean index b6f019ae..f3338d3a 160000 --- a/third-party/lms-clean +++ b/third-party/lms-clean @@ -1 +1 @@ -Subproject commit b6f019aef1df5f1f12bcd0b43a9136d7f9ce7704 +Subproject commit f3338d3ab0ea74e90e44acfdbbda7912c249a7fc From a83eb06f80e38e95a5a13722caceafbdc58f6e2e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 02:21:02 +0800 Subject: [PATCH 27/40] effectful staged interpreter --- src/main/scala/wasm/StagedMiniWasm.scala | 473 +++++++++++++---------- 1 file changed, 261 insertions(+), 212 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index bff5cca5..e04fe744 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -21,202 +21,186 @@ trait StagedWasmEvaluator extends SAIOps { // Adapter.resetState // Adapter.g = Adapter.mkGraphBuilder - trait Stack - type Cont[A] = Stack => A - type Trail[A] = List[Rep[Cont[A]]] + trait Slice trait Frame + type Cont[A] = Unit => A + type Trail[A] = List[Rep[Cont[A]]] + // a cache storing the compiled code for each function, to reduce re-compilation - val compileCache = new HashMap[Int, Rep[(Stack, Frame, Cont[Unit]) => Unit]] + val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] // NOTE: We don't support Ans type polymorphism yet def eval(insts: List[Instr], - stack: Rep[Stack], - frame: Rep[Frame], kont: Rep[Cont[Unit]], trail: Trail[Unit]): Rep[Unit] = { - if (insts.isEmpty) return kont(stack) + if (insts.isEmpty) return kont() val (inst, rest) = (insts.head, insts.tail) inst match { - case Drop => eval(rest, stack.tail, frame, kont, trail) - case WasmConst(num) => eval(rest, Values.lift(num) :: stack, frame, kont, trail) + case Drop => + Stack.pop() + eval(rest, kont, trail) + case WasmConst(num) => + Stack.push(num) + eval(rest, kont, trail) case LocalGet(i) => - eval(rest, frame.get(i) :: stack, frame, kont, trail) + Stack.push(Frames.get(i)) + eval(rest, kont, trail) case LocalSet(i) => - val (v, newStack) = (stack.head, stack.tail) - frame(i) = v - eval(rest, newStack, frame, kont, trail) + Frames.set(i, Stack.pop()) + eval(rest, kont, trail) case LocalTee(i) => - val v = stack.head - frame(i) = v - eval(rest, stack, frame, kont, trail) + Frames.set(i, Stack.peek) + eval(rest, kont, trail) case GlobalGet(i) => - eval(rest, Global.globalGet(i) :: stack, frame, kont, trail) + Stack.push(Global.globalGet(i)) + eval(rest, kont, trail) case GlobalSet(i) => - val (value, newStack) = (stack.head, stack.tail) + val value = Stack.pop() module.globals(i).ty match { case GlobalType(tipe, true) => Global.globalSet(i, value) case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) case MemorySize => ??? case MemoryGrow => ??? case MemoryFill => ??? case Nop => - eval(rest, stack, frame, kont, trail) + eval(rest, kont, trail) case Unreachable => unreachable() case Test(op) => - val (v, newStack) = (stack.head, stack.tail) - eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail) + val v = Stack.pop() + Stack.push(evalTestOp(op, v)) + eval(rest, kont, trail) case Unary(op) => - val (v, newStack) = (stack.head, stack.tail) - eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail) + val v = Stack.pop() + Stack.push(evalUnaryOp(op, v)) + eval(rest, kont, trail) case Binary(op) => - val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail) + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalBinOp(op, v1, v2)) + eval(rest, kont, trail) case Compare(op) => - val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail) + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalRelOp(op, v1, v2)) + eval(rest, kont, trail) case WasmBlock(ty, inner) => + // no need to modify the stack when entering a block + // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) - eval(inner, inputs, frame, restK, restK :: trail) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType - val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) - def loop(retStack: Rep[Stack]): Rep[Unit] = - eval(inner, retStack.take(funcTy.inps.size), frame, restK, fun(loop _) :: trail) - loop(inputs) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + def loop(_u: Rep[Unit]): Rep[Unit] = + eval(inner, restK, fun(loop _) :: trail) + loop(()) case If(ty, thn, els) => val funcTy = ty.funcType - val (cond, newStack) = (stack.head, stack.tail) - val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val cond = Stack.pop() // TODO: can we avoid code duplication here? - val restK = fun((retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) if (cond != Values.I32(0)) { - eval(thn, inputs, frame, restK, restK :: trail) + eval(thn, restK, restK :: trail) } else { - eval(els, inputs, frame, restK, restK :: trail) + eval(els, restK, restK :: trail) } case Br(label) => info(s"Jump to $label") - trail(label)(stack) + trail(label)(()) case BrIf(label) => - val (cond, newStack) = (stack.head, stack.tail) + val cond = Stack.pop() + info(s"The br_if(${label})'s condition is ", cond) if (cond != Values.I32(0)) { - info("The br_if's condition is ", cond) info(s"Jump to $label") - trail(label)(newStack) + trail(label)(()) } else { - info("The br_if's condition is ",cond) info(s"Continue") - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) } case BrTable(labels, default) => - val (cond, newStack) = (stack.head, stack.tail) + val cond = Stack.pop() def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(newStack) + if (choices.isEmpty) trail(default)(()) else { - if (cond.toInt == idx) trail(choices.head)(newStack) + if (cond.toInt == idx) trail(choices.head)(()) else aux(choices.tail, idx + 1) } } aux(labels, 0) - /* - if (cond.toInt < unit(labels.length)) { - // Implementation 1(trigger runtime exception): - // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) - // val goto: Rep[Cont[Unit]] = targets(cond.toInt) - // goto(newStack) // TODO: confirm why this line will trigger an exception - - // Implementation 2(if-expression is not generated at all): - // var goto: Rep[Cont[Unit]] = null - // for (i <- Range(0, labels.length)) { - // if (i != cond.toInt) { - // info(s"Jump(br_table) to ${labels(i)}") - // return trail(labels(i))(newStack) - // } - // } - - // Implementation 3(assignment to `goto` is not generated): - var goto: Rep[Cont[Unit]] = null - for (i <- Range(0, labels.length)) { - if (i != cond.toInt) { - info(s"Jump(br_table) to ${labels(i)}") - goto = trail(labels(i)) - } - } - info(s"Jump to goto target") - goto(newStack) - } else { - trail(default)(newStack) - } - */ - case Return => trail.last(stack) - case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) - case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) + case Return => trail.last(()) + case Call(f) => evalCall(rest, kont, trail, f, false) + case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => val todo = "todo-op".reflectCtrlWith() - eval(rest, todo :: stack, frame, kont, trail) + eval(rest, kont, trail) } } def evalCall(rest: List[Instr], - stack: Rep[Stack], - frame: Rep[Frame], kont: Rep[Cont[Unit]], trail: Trail[Unit], funcIndex: Int, isTail: Boolean): Rep[Unit] = { module.funcs(funcIndex) match { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => - val args = stack.take(ty.inps.size).reverse - val newStack = stack.drop(ty.inps.size) - val newFrame = frameOf(ty.inps.size + locals.size) - newFrame.putAll(args) - info("New frame:", newFrame) + val returnSize = Stack.size - ty.inps.size + ty.out.size + val args = Stack.take(ty.inps.size) + info("New frame:", Frames.top) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { - val callee = fun( - (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { - info(s"Entered the function at $funcIndex, stack =", stack, ", frame =", frame) - eval(body, stack, frame, kont, kont::Nil):Rep[Unit] + val callee = topFun( + (kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stackSize =", Stack.size, ", frame =", Frames.top) + eval(body, kont, kont::Nil): Rep[Unit] } ) compileCache(funcIndex) = callee callee } - if (isTail) + val frameSize = ty.inps.size + locals.size + if (isTail) { // when tail call, return to the caller's return continuation - callee(Stack.emptyStack, newFrame, trail.last) - else { - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) - ) + Frames.popFrame() + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(trail.last) + } else { + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(returnSize) + Frames.popFrame() + eval(rest, kont, trail) + }) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(Stack.emptyStack, newFrame, restK) + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(restK) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (v, newStack) = (stack.head, stack.tail) + val v = Stack.pop() println(v) - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -283,122 +267,125 @@ trait StagedWasmEvaluator extends SAIOps { } } val (instrs, localSize) = (funBody.body, funBody.locals.size) - val frame = frameOf(localSize) - eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + Stack.initialize() + Frames.pushFrame(localSize) + eval(instrs, kont, kont::Nil) + Frames.popFrame() } def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { - val haltK: Rep[Stack] => Rep[Unit] = stack => { + val haltK: Rep[Unit] => Rep[Unit] = (_) => { if (printRes) { - print("Final stack: ") - println(stack) + Stack.print() } "no-op".reflectCtrlWith[Unit]() } - evalTop(fun(haltK), main) + val temp: Rep[Cont[Unit]] = fun(haltK) + evalTop(temp, main) } // stack creation and operations object Stack { - def emptyStack: Rep[Stack] = { - "empty-stack".reflectWith() + def initialize(): Rep[Unit] = { + "stack-init".reflectCtrlWith() } - } - // call unreachable - def unreachable(): Rep[Unit] = { - "unreachable".reflectCtrlWith() - } - - def info(xs: Rep[_]*): Rep[Unit] = { - "info".reflectCtrlWith(xs: _*) - } + def pop(): Rep[Num] = { + "stack-pop".reflectCtrlWith() + } - // runtime values - object Values { - def lift(num: Num): Rep[Num] = { - num match { - case I32V(i) => I32(i) - case I64V(i) => I64(i) - } + def peek: Rep[Num] = { + "stack-peek".reflectCtrlWith() } - def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) + def push(v: Rep[Num]): Rep[Unit] = { + "stack-push".reflectCtrlWith(v) } - def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + def drop(n: Int): Rep[Unit] = { + "stack-drop".reflectCtrlWith(n) } - } - // global read/write - object Global{ - def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith(i) + def print(): Rep[Unit] = { + "stack-print".reflectCtrlWith() } - def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { - "global-set".reflectCtrlWith(i, value) + def size: Rep[Int] = { + "stack-size".reflectCtrlWith() } - } - // TODO: The stack should be allocated on the stack to get optimal performance - implicit class StackOps(stack: Rep[Stack]) { - def head: Rep[Num] = { - "stack-head".reflectCtrlWith(stack) + def reset(x: Rep[Int]): Rep[Unit] = { + "stack-reset".reflectCtrlWith(x) } - def tail: Rep[Stack] = { - "stack-tail".reflectCtrlWith(stack) + def take(n: Int): Rep[Slice] = { + "stack-take".reflectCtrlWith(n) } + } - def ::[A](v: Rep[Num]): Rep[Stack] = { - "stack-cons".reflectCtrlWith(v, stack) + object Frames { + def get(i: Int): Rep[Num] = { + "frame-get".reflectCtrlWith(i) } - def ++(v: Rep[Stack]): Rep[Stack] = { - "stack-append".reflectCtrlWith(stack, v) + def set(i: Int, v: Rep[Num]): Rep[Unit] = { + "frame-set".reflectCtrlWith(i, v) } - def take(n: Int): Rep[Stack] = { - if (n == 0) Stack.emptyStack - else "stack-take".reflectWith(stack, n) + def pushFrame(i: Int): Rep[Unit] = { + "frame-push".reflectCtrlWith(i) } - def drop(n: Int): Rep[Stack] = { - if (n == 0) stack - else "stack-drop".reflectWith(stack, n) + def popFrame(): Rep[Unit] = { + "frame-pop".reflectCtrlWith() } - def reverse: Rep[Stack] = { - "stack-reverse".reflectWith(stack) + def putAll(args: Rep[Slice]): Rep[Unit] = { + "frame-putAll".reflectCtrlWith(args) } - def splitAt(n: Int): (Rep[Stack], Rep[Stack]) = { - (take(n), drop(n)) + def top: Rep[Frame] = { + "frame-top".reflectCtrlWith() } } - // frame creation and operations - def frameOf(size: Int): Rep[Frame] = { - "frame-of".reflectCtrlWith(size) + + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith() } - implicit class FrameOps(frame: Rep[Frame]) { + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith(xs: _*) + } - def get(i: Int): Rep[Num] = { - "frame-get".reflectCtrlWith(frame, i) + // runtime values + object Values { + def lift(num: Num): Rep[Num] = { + num match { + case I32V(i) => I32(i) + case I64V(i) => I64(i) + } + } + + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) } - def putAll(args: Rep[Stack]) = { - "frame-putAll".reflectCtrlWith(frame, args) + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) } + } - def update(i: Int, value: Rep[Num]) = { - "frame-update".reflectCtrlWith(frame, i, value) + // global read/write + object Global{ + def globalGet(i: Int): Rep[Num] = { + "global-get".reflectWith(i) } + def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { + "global-set".reflectCtrlWith(i, value) + } } // runtime Num type @@ -446,49 +433,56 @@ trait StagedWasmEvaluator extends SAIOps { def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) } + implicit class SliceOps(slice: Rep[Slice]) { + def reverse: Rep[Slice] = "slice-reverse".reflectWith(slice) + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { - case Node(_, "frame-update", List(frame, i, value), _) => - // TODO: what is the protocol of automatic new line insertion? - shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") case _ => super.traverse(n) - case Node(_, "info", xs, _) => - emit("println("); xs.foreach { x => - shallow(x); emit(", ") - }; emit(")") - } // code generation for pure nodes override def shallow(n: Node): Unit = n match { - case Node(_, "stack-take", List(stack, n), _) => - shallow(stack); emit(".take("); shallow(n); emit(")") - case Node(_, "stack-drop", List(stack, n), _) => - shallow(stack); emit(".drop("); shallow(n); emit(")") - case Node(_, "stack-append", List(stack1, stack2), _) => - shallow(stack1); emit(".++("); shallow(stack2); emit(")") - case Node(_, "stack-head", List(stack), _) => - shallow(stack); emit(".head") - case Node(_, "stack-reverse", List(stack), _) => - shallow(stack); emit(".reverse") - case Node(_, "stack-cons", List(v, stack), _) => - shallow(stack); emit(".::("); shallow(v); emit(")") - case Node(_, "stack-tail", List(stack), _) => - shallow(stack); emit(".tail") - case Node(_, "empty-stack", _, _) => - emit("Nil") - case Node(_, "frame-of", List(size), _) => - emit("new Frame("); shallow(size); emit(")") - case Node(_, "frame-get", List(frame, i), _) => - shallow(frame); emit("("); shallow(i); emit(")") - case Node(_, "frame-putAll", List(frame, args), _) => - shallow(frame); emit(".putAll("); shallow(args); emit(")") + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek\n") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => @@ -573,16 +567,48 @@ object Prelude { case class I32V(i: Int) extends Num case class I64V(i: Long) extends Num +object Stack { + private val buffer = new scala.collection.mutable.ArrayBuffer[Num]() + def push(v: Num): Unit = buffer.append(v) + def pop(): Num = { + buffer.remove(buffer.size - 1) + } + def peek: Num = { + buffer.last + } + def size: Int = buffer.size + def drop(n: Int): Unit = { + buffer.remove(buffer.size - n, n) + } + def take(n: Int): List[Num] = { + val xs = buffer.takeRight(n).toList + drop(n) + xs + } + def reset(size: Int): Unit = { + info(s"Reset stack to size $size") + while (buffer.size > size) { + buffer.remove(buffer.size - 1) + } + } + def initialize(): Unit = buffer.clear() + def print(): Unit = { + println("Stack: " + buffer.mkString(", ")) + } +} - type Stack = List[Num] + type Slice = List[Num] class Frame(val size: Int) { private val data = new Array[Num](size) def apply(i: Int): Num = { - info(s"frame(${i}) = ${data(i)}") + info(s"frame(${i}) is ${data(i)}") data(i) } - def update(i: Int, v: Num): Unit = data(i) = v + def update(i: Int, v: Num): Unit = { + info(s"set frame(${i}) to ${v}") + data(i) = v + } def putAll(xs: List[Num]): Unit = { for (i <- 0 until xs.size) { data(i) = xs(i) @@ -593,6 +619,28 @@ object Prelude { } } + object Frames { + private var frames = List[Frame]() + def pushFrame(size: Int): Unit = { + frames = new Frame(size) :: frames + } + def popFrame(): Unit = { + frames = frames.tail + } + def top: Frame = frames.head + def set(i: Int, v: Num): Unit = { + top(i) = v + } + def get(i: Int): Num = { + top(i) + } + def putAll(xs: Slice) = { + for (i <- 0 until xs.size) { + top(i) = xs(i) + } + } + } + object Global { // TODO: create global with specific size private val globals = new Array[Num](10) @@ -608,6 +656,7 @@ object Prelude { } import Prelude._ + object Main { def main(args: Array[String]): Unit = { val snippet = new Snippet() From b8a9aea960c30ac2a3e3134aa8458bbabe22cdd3 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 02:21:13 +0800 Subject: [PATCH 28/40] remove non-sense tests --- benchmarks/wasm/staged/scratch.wat | 45 -------------------- src/test/scala/genwasym/TestStagedEval.scala | 6 --- 2 files changed, 51 deletions(-) delete mode 100644 benchmarks/wasm/staged/scratch.wat diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat deleted file mode 100644 index b725d770..00000000 --- a/benchmarks/wasm/staged/scratch.wat +++ /dev/null @@ -1,45 +0,0 @@ -;; this file contains some wasm instructions to test if the compiler works, -;; and its execution is meaningless. -(module $push-drop - (global (;0;) (mut i32) (i32.const 1048576)) - (func (;0;) (type 1) (result i32) - (local i32 i32) - i32.const 2 - i32.const 2 - local.get 0 - local.get 1 - local.set 0 - local.tee 1 - drop - drop - i32.add - nop - (call 1) - global.get 0 - i32.const 3 - global.set 0 - - if (result i32) ;; label = @1 - i32.const 1 - else - local.get 1 - end - (block - (block - i32.const 4 - i32.const 2 - br_table 0 1 0 ;; the compilation of br_table is problematic now - ) - ) - - (loop - i32.const 5 - br 0) - return - i32.const 6 - ) - (func (;1;) (type 1) (param i32 i32) (result i32) - (local i32 i32) - local.get 0 - local.get 1) - (start 0)) \ No newline at end of file diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index b572f90f..4c46fc5b 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -14,13 +14,7 @@ class TestStagedEval extends FunSuite { println(code) } - /* - test("scratch") { - testFile("./benchmarks/wasm/staged/scratch.wat") - } - test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } - */ test("brtable") { testFile("./benchmarks/wasm/staged/brtable.wat") From b7b87867a2d6e9ff5f0c029770b2898b89b6ef90 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 21:57:31 +0800 Subject: [PATCH 29/40] scratch cpp backend --- src/main/scala/wasm/StagedMiniWasm.scala | 129 ++++++++++++++++++- src/test/scala/genwasym/TestStagedEval.scala | 23 +++- 2 files changed, 141 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index e04fe744..867638e2 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -5,13 +5,13 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import lms.core.stub.Adapter import lms.core.virtualize import lms.macros.SourceContext -import lms.core.stub.{Base, ScalaGenBase} +import lms.core.stub.{Base, ScalaGenBase, CGenBase} import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} -import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} @virtualize trait StagedWasmEvaluator extends SAIOps { @@ -472,7 +472,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => - emit("Stack.peek\n") + emit("Stack.peek") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => @@ -525,7 +525,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { } } -trait WasmCompilerDriver[A, B] +trait WasmToScalaCompilerDriver[A, B] extends SAIDriver[A, B] with StagedWasmEvaluator { q => override val codegen = new StagedWasmScalaGen { val IR: q.type = q @@ -533,6 +533,7 @@ trait WasmCompilerDriver[A, B] override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Stack")) "Stack" else if(m.toString.endsWith("Frame")) "Frame" + else if(m.toString.endsWith("Slice")) "Slice" else super.remap(m) } } @@ -666,10 +667,125 @@ object Main { """ } -object PartialEvaluator { + + +object WasmToScalaCompiler { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmToScalaCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + code.code + } +} + +trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + // for now, the traverse/shallow is same as the scala backend's + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") + case _ => super.shallow(n) + } +} + + +trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmCppGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else super.remap(m) + } + } +} + +object WasmToCppCompiler { def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") - val code = new WasmCompilerDriver[Unit, Unit] { + val code = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { evalTop(main, printRes) @@ -678,3 +794,4 @@ object PartialEvaluator { code.code } } + diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 4c46fc5b..47afddce 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -8,15 +8,28 @@ import gensym.wasm.parser._ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { - def testFile(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, main, true) + val code = WasmToScalaCompiler(moduleInst, main, true) println(code) } - test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + test("ack-scala") { testFileToScala("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } - test("brtable") { - testFile("./benchmarks/wasm/staged/brtable.wat") + test("brtable-scala") { + testFileToScala("./benchmarks/wasm/staged/brtable.wat") } + + def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = WasmToCppCompiler(moduleInst, main, true) + println(code) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + + test("brtable-cpp") { + testFileToCpp("./benchmarks/wasm/staged/brtable.wat") + } + } From 0a8339e46165459adf3d900f3fce8bfe1f6a0caf Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 7 May 2025 01:39:17 +0800 Subject: [PATCH 30/40] some tweaks --- src/main/scala/wasm/StagedMiniWasm.scala | 31 +++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 867638e2..aba6ec11 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -91,6 +91,7 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType + // TODO: somehow the type of exitSize in residual program is nothing val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { Stack.reset(exitSize) @@ -686,25 +687,25 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(")\n") + emit("Stack.push("); shallow(value); emit(");\n") case Node(_, "stack-drop", List(n), _) => - emit("Stack.drop("); shallow(n); emit(")\n") + emit("Stack.drop("); shallow(n); emit(");\n") case Node(_, "stack-reset", List(n), _) => - emit("Stack.reset("); shallow(n); emit(")\n") + emit("Stack.reset("); shallow(n); emit(");\n") case Node(_, "stack-init", _, _) => - emit("Stack.initialize()\n") + emit("Stack.initialize();\n") case Node(_, "stack-print", _, _) => - emit("Stack.print()\n") + emit("Stack.print();\n") case Node(_, "frame-push", List(i), _) => - emit("Frames.pushFrame("); shallow(i); emit(")\n") + emit("Frames.pushFrame("); shallow(i); emit(");\n") case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame()\n") + emit("Frames.popFrame();\n") case Node(_, "frame-putAll", List(args), _) => - emit("Frames.putAll("); shallow(args); emit(")\n") + emit("Frames.putAll("); shallow(args); emit(");\n") case Node(_, "frame-set", List(i, value), _) => - emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") case Node(_, "global-set", List(i, value), _) => - emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") case _ => super.traverse(n) } @@ -723,11 +724,11 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "slice-reverse", List(slice), _) => shallow(slice); emit(".reverse") case Node(_, "stack-size", _, _) => - emit("Stack.size") + emit("Stack.size()") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") case Node(_, "frame-top", _, _) => - emit("Frames.top") + emit("Frames.top()") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => @@ -777,6 +778,12 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv import IR._ override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Slice")) "Slice" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" else super.remap(m) } } From b4703c7995c032dcfd7815e71dbd84c9c4c05d23 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Mon, 12 May 2025 17:36:45 +0200 Subject: [PATCH 31/40] fix some of the nothing type --- src/main/scala/wasm/StagedMiniWasm.scala | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index aba6ec11..2098b6aa 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -292,7 +292,7 @@ trait StagedWasmEvaluator extends SAIOps { } def pop(): Rep[Num] = { - "stack-pop".reflectCtrlWith() + "stack-pop".reflectCtrlWith[Num]() } def peek: Rep[Num] = { @@ -312,7 +312,7 @@ trait StagedWasmEvaluator extends SAIOps { } def size: Rep[Int] = { - "stack-size".reflectCtrlWith() + "stack-size".reflectCtrlWith[Int]() } def reset(x: Rep[Int]): Rep[Unit] = { @@ -392,7 +392,7 @@ trait StagedWasmEvaluator extends SAIOps { // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = "num-to-int".reflectWith(num) + def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) def clz(): Rep[Num] = "unary-clz".reflectWith(num) @@ -684,6 +684,17 @@ object WasmToScalaCompiler { } trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Slice")) "Slice" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" + else super.remap(m) + } + // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => @@ -776,16 +787,6 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv override val codegen = new StagedWasmCppGen { val IR: q.type = q import IR._ - override def remap(m: Manifest[_]): String = { - if (m.toString.endsWith("Num")) "Num" - else if (m.toString.endsWith("Slice")) "Slice" - else if (m.toString.endsWith("Frame")) "Frame" - else if (m.toString.endsWith("Stack")) "Stack" - else if (m.toString.endsWith("Global")) "Global" - else if (m.toString.endsWith("I32V")) "I32V" - else if (m.toString.endsWith("I64V")) "I64V" - else super.remap(m) - } } } From 29acef0c678f29824bf02c285c417872b2e2a3dc Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 13 May 2025 11:06:08 +0800 Subject: [PATCH 32/40] manually supply the reflect's type arguments --- src/main/scala/wasm/StagedMiniWasm.scala | 80 ++++++++++++------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2098b6aa..c1358894 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -149,7 +149,7 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => - val todo = "todo-op".reflectCtrlWith() + val todo = "todo-op".reflectCtrlWith[Unit]() eval(rest, kont, trail) } } @@ -288,7 +288,7 @@ trait StagedWasmEvaluator extends SAIOps { // stack creation and operations object Stack { def initialize(): Rep[Unit] = { - "stack-init".reflectCtrlWith() + "stack-init".reflectCtrlWith[Unit]() } def pop(): Rep[Num] = { @@ -296,19 +296,19 @@ trait StagedWasmEvaluator extends SAIOps { } def peek: Rep[Num] = { - "stack-peek".reflectCtrlWith() + "stack-peek".reflectCtrlWith[Num]() } def push(v: Rep[Num]): Rep[Unit] = { - "stack-push".reflectCtrlWith(v) + "stack-push".reflectCtrlWith[Unit](v) } def drop(n: Int): Rep[Unit] = { - "stack-drop".reflectCtrlWith(n) + "stack-drop".reflectCtrlWith[Unit](n) } def print(): Rep[Unit] = { - "stack-print".reflectCtrlWith() + "stack-print".reflectCtrlWith[Unit]() } def size: Rep[Int] = { @@ -316,17 +316,17 @@ trait StagedWasmEvaluator extends SAIOps { } def reset(x: Rep[Int]): Rep[Unit] = { - "stack-reset".reflectCtrlWith(x) + "stack-reset".reflectCtrlWith[Unit](x) } def take(n: Int): Rep[Slice] = { - "stack-take".reflectCtrlWith(n) + "stack-take".reflectCtrlWith[Slice](n) } } object Frames { def get(i: Int): Rep[Num] = { - "frame-get".reflectCtrlWith(i) + "frame-get".reflectCtrlWith[Num](i) } def set(i: Int, v: Rep[Num]): Rep[Unit] = { @@ -334,30 +334,30 @@ trait StagedWasmEvaluator extends SAIOps { } def pushFrame(i: Int): Rep[Unit] = { - "frame-push".reflectCtrlWith(i) + "frame-push".reflectCtrlWith[Unit](i) } def popFrame(): Rep[Unit] = { - "frame-pop".reflectCtrlWith() + "frame-pop".reflectCtrlWith[Unit]() } def putAll(args: Rep[Slice]): Rep[Unit] = { - "frame-putAll".reflectCtrlWith(args) + "frame-putAll".reflectCtrlWith[Unit](args) } def top: Rep[Frame] = { - "frame-top".reflectCtrlWith() + "frame-top".reflectCtrlWith[Frame]() } } // call unreachable def unreachable(): Rep[Unit] = { - "unreachable".reflectCtrlWith() + "unreachable".reflectCtrlWith[Unit]() } def info(xs: Rep[_]*): Rep[Unit] = { - "info".reflectCtrlWith(xs: _*) + "info".reflectCtrlWith[Unit](xs: _*) } // runtime values @@ -370,22 +370,22 @@ trait StagedWasmEvaluator extends SAIOps { } def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) + "I32V".reflectWith[Num](i) } def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + "I64V".reflectWith[Num](i) } } // global read/write object Global{ def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith(i) + "global-get".reflectWith[Num](i) } def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { - "global-set".reflectCtrlWith(i, value) + "global-set".reflectCtrlWith[Unit](i, value) } } @@ -394,48 +394,48 @@ trait StagedWasmEvaluator extends SAIOps { def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) - def clz(): Rep[Num] = "unary-clz".reflectWith(num) + def clz(): Rep[Num] = "unary-clz".reflectWith[Num](num) - def ctz(): Rep[Num] = "unary-ctz".reflectWith(num) + def ctz(): Rep[Num] = "unary-ctz".reflectWith[Num](num) - def popcnt(): Rep[Num] = "unary-popcnt".reflectWith(num) + def popcnt(): Rep[Num] = "unary-popcnt".reflectWith[Num](num) - def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith(num, rhs) + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith[Num](num, rhs) - def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith(num, rhs) + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith[Num](num, rhs) - def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith(num, rhs) + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith[Num](num, rhs) - def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith(num, rhs) + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith[Num](num, rhs) - def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith(num, rhs) + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith[Num](num, rhs) - def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith(num, rhs) + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith[Num](num, rhs) - def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith(num, rhs) + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith[Num](num, rhs) - def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith(num, rhs) + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith[Num](num, rhs) - def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith(num, rhs) + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith[Num](num, rhs) - def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith(num, rhs) + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith[Num](num, rhs) - def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith(num, rhs) + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith[Num](num, rhs) - def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith(num, rhs) + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith[Num](num, rhs) - def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith(num, rhs) + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith[Num](num, rhs) - def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith(num, rhs) + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith[Num](num, rhs) - def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith(num, rhs) + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith[Num](num, rhs) - def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith[Num](num, rhs) - def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith[Num](num, rhs) } implicit class SliceOps(slice: Rep[Slice]) { - def reverse: Rep[Slice] = "slice-reverse".reflectWith(slice) + def reverse: Rep[Slice] = "slice-reverse".reflectWith[Slice](slice) } } From 67b077bb02c903bd2d069cc0b81c8f777c744fc2 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 14 May 2025 21:27:55 +0800 Subject: [PATCH 33/40] lift every function to top level & avoid lms's common subexpr elimination --- src/main/scala/wasm/StagedMiniWasm.scala | 90 ++++++++++++++---------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index c1358894..dfbc4a7f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -16,10 +16,6 @@ import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, @virtualize trait StagedWasmEvaluator extends SAIOps { def module: ModuleInstance - // NOTE: we don't need the following statements anymore, but where are they initialized? - // reset and initialize the internal state of Adapter - // Adapter.resetState - // Adapter.g = Adapter.mkGraphBuilder trait Slice @@ -93,7 +89,7 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType // TODO: somehow the type of exitSize in residual program is nothing val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + def restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) @@ -101,19 +97,20 @@ trait StagedWasmEvaluator extends SAIOps { case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - val restK = fun((_: Rep[Unit]) => { + def restK = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) - def loop(_u: Rep[Unit]): Rep[Unit] = - eval(inner, restK, fun(loop _) :: trail) + def loop : Rep[Unit => Unit] = topFun((_u: Rep[Unit]) => { + eval(inner, restK, loop :: trail) + }) loop(()) case If(ty, thn, els) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() // TODO: can we avoid code duplication here? - val restK = fun((_: Rep[Unit]) => { + def restK = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) @@ -185,7 +182,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.putAll(args) callee(trail.last) } else { - val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + val restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { Stack.reset(returnSize) Frames.popFrame() eval(rest, kont, trail) @@ -281,7 +278,7 @@ trait StagedWasmEvaluator extends SAIOps { } "no-op".reflectCtrlWith[Unit]() } - val temp: Rep[Cont[Unit]] = fun(haltK) + val temp: Rep[Cont[Unit]] = topFun(haltK) evalTop(temp, main) } @@ -370,18 +367,18 @@ trait StagedWasmEvaluator extends SAIOps { } def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith[Num](i) + "I32V".reflectCtrlWith[Num](i) } def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith[Num](i) + "I64V".reflectCtrlWith[Num](i) } } // global read/write object Global{ def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith[Num](i) + "global-get".reflectCtrlWith[Num](i) } def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { @@ -392,50 +389,50 @@ trait StagedWasmEvaluator extends SAIOps { // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) + def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num) - def clz(): Rep[Num] = "unary-clz".reflectWith[Num](num) + def clz(): Rep[Num] = "unary-clz".reflectCtrlWith[Num](num) - def ctz(): Rep[Num] = "unary-ctz".reflectWith[Num](num) + def ctz(): Rep[Num] = "unary-ctz".reflectCtrlWith[Num](num) - def popcnt(): Rep[Num] = "unary-popcnt".reflectWith[Num](num) + def popcnt(): Rep[Num] = "unary-popcnt".reflectCtrlWith[Num](num) - def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith[Num](num, rhs) + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectCtrlWith[Num](num, rhs) - def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith[Num](num, rhs) + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectCtrlWith[Num](num, rhs) - def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith[Num](num, rhs) + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectCtrlWith[Num](num, rhs) - def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith[Num](num, rhs) + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectCtrlWith[Num](num, rhs) - def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith[Num](num, rhs) + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectCtrlWith[Num](num, rhs) - def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith[Num](num, rhs) + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectCtrlWith[Num](num, rhs) - def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith[Num](num, rhs) + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectCtrlWith[Num](num, rhs) - def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith[Num](num, rhs) + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectCtrlWith[Num](num, rhs) - def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith[Num](num, rhs) + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectCtrlWith[Num](num, rhs) - def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith[Num](num, rhs) + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectCtrlWith[Num](num, rhs) - def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith[Num](num, rhs) + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectCtrlWith[Num](num, rhs) - def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith[Num](num, rhs) + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectCtrlWith[Num](num, rhs) - def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith[Num](num, rhs) + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectCtrlWith[Num](num, rhs) - def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith[Num](num, rhs) + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectCtrlWith[Num](num, rhs) - def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith[Num](num, rhs) + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectCtrlWith[Num](num, rhs) - def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith[Num](num, rhs) + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectCtrlWith[Num](num, rhs) - def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith[Num](num, rhs) + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectCtrlWith[Num](num, rhs) } implicit class SliceOps(slice: Rep[Slice]) { - def reverse: Rep[Slice] = "slice-reverse".reflectWith[Slice](slice) + def reverse: Rep[Slice] = "slice-reverse".reflectCtrlWith[Slice](slice) } } @@ -729,7 +726,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => - emit("Stack.peek") + emit("Stack.peek()") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => @@ -775,11 +772,26 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "relation-geu", List(lhs, rhs), _) => shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => - shallow(num); emit(".toInt") + shallow(num); emit(".toInt()") case Node(_, "no-op", _, _) => - emit("()") + emit("std::monostate()") case _ => super.shallow(n) } + + override def registerTopLevelFunction(id: String, streamId: String = "general")(f: => Unit) = + if (!registeredFunctions(id)) { + //if (ongoingFun(streamId)) ??? + //ongoingFun += streamId + registeredFunctions += id + withStream(functionsStreams.getOrElseUpdate(id, { + val functionsStream = new java.io.ByteArrayOutputStream() + val functionsWriter = new java.io.PrintStream(functionsStream) + (functionsWriter, functionsStream) + })._1)(f) + //ongoingFun -= streamId + } else { + withStream(functionsStreams(id)._1)(f) + } } From 6e41521a7704cbb0b7c7b55c58f78f3a81decd45 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 14 May 2025 21:50:51 +0800 Subject: [PATCH 34/40] stack pop example --- benchmarks/wasm/staged/pop.wat | 8 ++++++++ src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- src/test/scala/genwasym/TestStagedEval.scala | 4 ++++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 benchmarks/wasm/staged/pop.wat diff --git a/benchmarks/wasm/staged/pop.wat b/benchmarks/wasm/staged/pop.wat new file mode 100644 index 00000000..691839b7 --- /dev/null +++ b/benchmarks/wasm/staged/pop.wat @@ -0,0 +1,8 @@ +(module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) + (func (;0;) (type 1) (result) + i32.const 2 + i32.const 2 + i32.add + ) + (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index dfbc4a7f..8fb12001 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -442,6 +442,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()\n") case Node(_, "stack-reset", List(n), _) => emit("Stack.reset("); shallow(n); emit(")\n") case Node(_, "stack-init", _, _) => @@ -465,8 +467,6 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def shallow(n: Node): Unit = n match { case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") - case Node(_, "stack-pop", _, _) => - emit("Stack.pop()") case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 47afddce..c96f9b7e 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -20,6 +20,10 @@ class TestStagedEval extends FunSuite { testFileToScala("./benchmarks/wasm/staged/brtable.wat") } + test("drop-scala") { + testFileToScala("./benchmarks/wasm/staged/pop.wat") + } + def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val code = WasmToCppCompiler(moduleInst, main, true) From 9f04722faa780ed516b78d1936bd24a2a9bab443 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Thu, 15 May 2025 23:11:19 +0200 Subject: [PATCH 35/40] not inlining + shallow --- src/main/scala/wasm/StagedMiniWasm.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 8fb12001..d35318f5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -437,13 +437,16 @@ trait StagedWasmEvaluator extends SAIOps { } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def mayInline(n: Node): Boolean = n match { + case Node(s, "stack-pop", _, _) => false + case _ => super.mayInline(n) + } + override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")\n") - case Node(_, "stack-pop", _, _) => - emit("Stack.pop()\n") case Node(_, "stack-reset", List(n), _) => emit("Stack.reset("); shallow(n); emit(")\n") case Node(_, "stack-init", _, _) => @@ -469,6 +472,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Frames.get("); shallow(i); emit(")") case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") case Node(_, "stack-peek", _, _) => emit("Stack.peek") case Node(_, "stack-take", List(n), _) => @@ -666,7 +671,6 @@ object Main { } - object WasmToScalaCompiler { def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") From ed9c8e42b6948ef74839ba5c204602245fb9e3f8 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 17 May 2025 11:58:36 +0800 Subject: [PATCH 36/40] an almost work runtime --- src/main/scala/wasm/StagedMiniWasm.scala | 323 ++++++++++++++++++++++- 1 file changed, 312 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index d35318f5..aecd6391 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -8,6 +8,7 @@ import lms.macros.SourceContext import lms.core.stub.{Base, ScalaGenBase, CGenBase} import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} +import lms.core.Graph import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} @@ -88,20 +89,17 @@ trait StagedWasmEvaluator extends SAIOps { // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType // TODO: somehow the type of exitSize in residual program is nothing - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) - def loop : Rep[Unit => Unit] = topFun((_u: Rep[Unit]) => { + def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => { eval(inner, restK, loop :: trail) }) loop(()) @@ -110,8 +108,7 @@ trait StagedWasmEvaluator extends SAIOps { val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() // TODO: can we avoid code duplication here? - def restK = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) if (cond != Values.I32(0)) { @@ -182,8 +179,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.putAll(args) callee(trail.last) } else { - val restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { - Stack.reset(returnSize) + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { Frames.popFrame() eval(rest, kont, trail) }) @@ -278,7 +274,7 @@ trait StagedWasmEvaluator extends SAIOps { } "no-op".reflectCtrlWith[Unit]() } - val temp: Rep[Cont[Unit]] = topFun(haltK) + val temp: Rep[Cont[Unit]] = fun(haltK) evalTop(temp, main) } @@ -796,6 +792,310 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { } else { withStream(functionsStreams(id)._1)(f) } + + override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { + val ng = init(g) + emitln(prelude) + emitln(""" + |/***************************************** + |Emitting Generated Code + |*******************************************/ + """.stripMargin) + emitln(""" +#include +#include +#include +#include +#include """) + val src = run(name, ng) + emit(src) + emitln(""" + |/***************************************** + |End of Generated Code + |*******************************************/ + |int main(int argc, char *argv[]) { + | Snippet(std::monostate{}); + | return 0; + |}""".stripMargin) + } + + val prelude = """ +#include +#include +#include +#include +#include +#include +#include +#include + +#define info(x, ...) + +class Num_t { +public: + virtual std::unique_ptr clone() const = 0; + + virtual void display() = 0; + virtual int32_t toInt() = 0; + virtual int64_t toLong() = 0; +}; + +class I32V_t : public Num_t { +public: + I32V_t(int32_t value) : value_(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + void display() override { std::cout << value_ << std::endl; } + + int32_t toInt() override { return value_; } + + int64_t toLong() override { return static_cast(value_); } + +private: + int32_t value_; +}; + +class I64V_t : public Num_t { +public: + I64V_t(int64_t value) : value_(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + void display() override { std::cout << value_ << std::endl; } + + int32_t toInt() override { return static_cast(value_); } + + int64_t toLong() override { return value_; } + +private: + int64_t value_; +}; + +struct Num { + std::unique_ptr num_ptr; + + // Constructions and destruction + Num() : num_ptr(nullptr) {} + + Num(std::unique_ptr num_ptr_) : num_ptr(std::move(num_ptr_)) {} + + Num &operator=(const Num &other) { + if (this != &other) { + num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; + } + return *this; + } + + Num(const Num &other) { + num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; + } + + Num(Num &&other) noexcept = default; + + Num &operator=(Num &&other) noexcept = default; + + ~Num() = default; + + int32_t toInt() const { return num_ptr->toInt(); } + + int32_t toLong() const { return num_ptr->toLong(); } + + void display() const { num_ptr->display(); } + + Num operator+(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I32V_t(this->toInt() + other.toInt()))); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I64V_t(this->toLong() + other.toLong()))); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + Num operator-(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I32V_t(this->toInt() - other.toInt()))); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I64V_t(this->toLong() - other.toLong()))); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + bool operator==(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return this->toInt() == other.toInt(); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return this->toLong() == other.toLong(); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + bool operator!=(const Num &other) const { return !(this->operator==(other)); } +}; + +static Num I32V(int v) { return Num(std::make_unique(v)); } + +static Num I64V(int64_t v) { return Num(std::make_unique(v)); } + +// struct Slice { +// int32_t start; +// int32_t end; +// Slice(int32_t start_, int32_t end_) : start(start_), end(end_) {} +// }; + +using Slice = std::vector; + +class Stack_t { +public: + void push(Num &&num) { + assert(num.num_ptr != nullptr); + stack_.push_back(std::move(num)); + } + + void push(Num &num) { + assert(num.num_ptr != nullptr); + stack_.push_back(num); + } + + Num pop() { + if (stack_.empty()) { + throw std::runtime_error("Stack underflow"); + } + Num num = std::move(stack_.back()); + assert(num.num_ptr != nullptr); + stack_.pop_back(); + return num; + } + + Num peek() { + if (stack_.empty()) { + throw std::runtime_error("Stack underflow"); + } + return stack_.back(); + } + + Num get(int32_t index) { + assert(index >= 0); + assert(index < stack_.size()); + return stack_[index]; +} + + int32_t size() { return stack_.size(); } + + void reset(int32_t size) { + if (size > stack_.size()) { + throw std::out_of_range("Invalid size"); + } + while (stack_.size() > size) { + stack_.pop_back(); + } + } + + Slice take(int32_t size) { + if (size > stack_.size()) { + throw std::out_of_range("Invalid size"); + } + // todo: avoid re-allocation + Slice slice(stack_.end() - size, stack_.end()); + stack_.resize(stack_.size() - size); + return slice; + } + + void print() { + std::cout << "Stack contents: " << std::endl; + for (const auto &num : stack_) { + num.display(); + } + } + + void initialize() { stack_.clear(); } + +private: + std::vector stack_; +}; +static Stack_t Stack; + +struct Frame_t { + std::vector locals; + + Frame_t(std::int32_t size) : locals() { locals.resize(size); } + Num &operator[](std::int32_t index) { + assert(index >= 0); + if (index >= locals.size()) { + throw std::out_of_range("Index out of range"); + } + return locals[index]; + } + void putAll(Slice slice) { + for (std::int32_t i = 0; i < slice.size(); ++i) { + locals[i] = slice[i]; + } + } +}; + +class Frames_t { +public: + std::monostate popFrame() { + if (!frames.empty()) { + frames.pop_back(); + return std::monostate{}; + } else { + std::cout << "No frames to pop." << std::endl; + throw std::runtime_error("No frames to pop."); + } + } + + Num get(std::int32_t index) { + auto ret = top()[index]; + assert(ret.num_ptr != nullptr); + return ret; + } + + void set(std::int32_t index, Num num) { frames.back()[index] = num; } + + Frame_t &top() { + if (frames.empty()) { + throw std::runtime_error("No frames available"); + } + return frames.back(); + } + + void pushFrame(std::int32_t size) { + Frame_t frame(size); + frames.push_back(frame); + } + + void putAll(Slice slice) { + top().putAll(slice); + } + +private: + std::vector frames; +}; + +static Frames_t Frames; + +static void initRand() { + // for now, just do nothing +} + """ } @@ -817,5 +1117,6 @@ object WasmToCppCompiler { } code.code } + } From d5ed20d52078a28bf66ba87139b9571ae558d049 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 17 May 2025 14:57:20 +0800 Subject: [PATCH 37/40] emit functions --- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index aecd6391..774b921d 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -801,6 +801,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { |Emitting Generated Code |*******************************************/ """.stripMargin) + emitln(""" #include #include @@ -808,6 +809,9 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { #include #include """) val src = run(name, ng) + emitFunctionDecls(stream) + emitDatastructures(stream) + emitFunctions(stream) emit(src) emitln(""" |/***************************************** From 4fb5424e076019d2910849047fe3aafb4b06aa79 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 10:36:42 +0800 Subject: [PATCH 38/40] read a dummy node to avoid lambda lifting it seems that the lambda lifting is unsound --- src/main/scala/wasm/StagedMiniWasm.scala | 35 +++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 774b921d..25037e27 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -88,18 +88,26 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType + val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: somehow the type of exitSize in residual program is nothing def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + info(s"Exiting the block, stackSize =", Stack.size) + "dummy-op".reflectCtrlWith[Unit](dummy) eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val dummy = "dummy".reflectCtrlWith[Unit]() def restK = fun((_: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Exiting the loop, stackSize =", Stack.size) eval(rest, kont, trail) }) def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Entered the loop, stackSize =", Stack.size) eval(inner, restK, loop :: trail) }) loop(()) @@ -107,8 +115,11 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() + val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: can we avoid code duplication here? def restK = fun((_: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) }) if (cond != Values.I32(0)) { @@ -121,7 +132,7 @@ trait StagedWasmEvaluator extends SAIOps { trail(label)(()) case BrIf(label) => val cond = Stack.pop() - info(s"The br_if(${label})'s condition is ", cond) + info(s"The br_if(${label})'s condition is ", cond.toInt) if (cond != Values.I32(0)) { info(s"Jump to $label") trail(label)(()) @@ -157,14 +168,14 @@ trait StagedWasmEvaluator extends SAIOps { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => val returnSize = Stack.size - ty.inps.size + ty.out.size val args = Stack.take(ty.inps.size) - info("New frame:", Frames.top) + // info("New frame:", Frames.top) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { val callee = topFun( (kont: Rep[Cont[Unit]]) => { - info(s"Entered the function at $funcIndex, stackSize =", Stack.size, ", frame =", Frames.top) + info(s"Entered the function at $funcIndex, stackSize =", Stack.size) eval(body, kont, kont::Nil): Rep[Unit] } ) @@ -180,6 +191,7 @@ trait StagedWasmEvaluator extends SAIOps { callee(trail.last) } else { val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrame() eval(rest, kont, trail) }) @@ -269,6 +281,7 @@ trait StagedWasmEvaluator extends SAIOps { def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { val haltK: Rep[Unit] => Rep[Unit] = (_) => { + info("Exiting the program...") if (printRes) { Stack.print() } @@ -773,6 +786,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => shallow(num); emit(".toInt()") + case Node(_, "dummy", _, _) => emit("std::monostate()") + case Node(_, "dummy-op", _, _) => emit("std::monostate()") case Node(_, "no-op", _, _) => emit("std::monostate()") case _ => super.shallow(n) @@ -833,7 +848,19 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { #include #include -#define info(x, ...) +void info() { +#ifdef DEBUG + std::cout << std::endl; +#endif +} + +template +void info(const T &first, const Args &...args) { +#ifdef DEBUG + std::cout << first << " "; + info(args...); +#endif +} class Num_t { public: From 8e293b8dbe5648d84669e39fbdefcd771c9f2cb0 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 10:49:51 +0800 Subject: [PATCH 39/40] capture by value is not friendly with recursion --- src/main/scala/wasm/StagedMiniWasm.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 25037e27..7920c0dc 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -727,6 +727,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + // Node: This code is copied from the traverse of CppSAICodeGenBase.scala, try to avoid code duplication + val retType = remap(typeBlockRes(b.res)) + val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + emit(quote(f)); emit(" = ") + quoteTypedBlock(b, false, true, capture = "&") + emitln(";") case _ => super.traverse(n) } From 51dd632c3fb990f0cabdb7134b2e63138b09a14c Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 11:00:08 +0800 Subject: [PATCH 40/40] redirect generated code to a file --- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- src/test/scala/genwasym/TestStagedEval.scala | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 7920c0dc..2bea6cd5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -681,7 +681,7 @@ object Main { object WasmToScalaCompiler { - def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmToScalaCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst @@ -1146,7 +1146,7 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv } object WasmToCppCompiler { - def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index c96f9b7e..a74fbc6f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -10,7 +10,7 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = WasmToScalaCompiler(moduleInst, main, true) + val code = WasmToScalaCompiler.compile(moduleInst, main, true) println(code) } @@ -26,7 +26,15 @@ class TestStagedEval extends FunSuite { def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = WasmToCppCompiler(moduleInst, main, true) + val code = WasmToCppCompiler.compile(moduleInst, main, true) + if (printRes) { + val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp")) + try { + writer.write(code) + } finally { + writer.close() + } + } println(code) }