Skip to content

Commit ed9c8e4

Browse files
an almost work runtime
1 parent 9f04722 commit ed9c8e4

File tree

1 file changed

+312
-11
lines changed

1 file changed

+312
-11
lines changed

src/main/scala/wasm/StagedMiniWasm.scala

Lines changed: 312 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import lms.macros.SourceContext
88
import lms.core.stub.{Base, ScalaGenBase, CGenBase}
99
import lms.core.Backend._
1010
import lms.core.Backend.{Block => LMSBlock}
11+
import lms.core.Graph
1112

1213
import gensym.wasm.ast._
1314
import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock}
@@ -88,20 +89,17 @@ trait StagedWasmEvaluator extends SAIOps {
8889
// the type system guarantees that we will never take more than the input size from the stack
8990
val funcTy = ty.funcType
9091
// TODO: somehow the type of exitSize in residual program is nothing
91-
val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size
92-
def restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => {
93-
Stack.reset(exitSize)
92+
def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => {
9493
eval(rest, kont, trail)
9594
})
9695
eval(inner, restK, restK :: trail)
9796
case Loop(ty, inner) =>
9897
val funcTy = ty.funcType
9998
val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size
100-
def restK = topFun((_: Rep[Unit]) => {
101-
Stack.reset(exitSize)
99+
def restK = fun((_: Rep[Unit]) => {
102100
eval(rest, kont, trail)
103101
})
104-
def loop : Rep[Unit => Unit] = topFun((_u: Rep[Unit]) => {
102+
def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => {
105103
eval(inner, restK, loop :: trail)
106104
})
107105
loop(())
@@ -110,8 +108,7 @@ trait StagedWasmEvaluator extends SAIOps {
110108
val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size
111109
val cond = Stack.pop()
112110
// TODO: can we avoid code duplication here?
113-
def restK = topFun((_: Rep[Unit]) => {
114-
Stack.reset(exitSize)
111+
def restK = fun((_: Rep[Unit]) => {
115112
eval(rest, kont, trail)
116113
})
117114
if (cond != Values.I32(0)) {
@@ -182,8 +179,7 @@ trait StagedWasmEvaluator extends SAIOps {
182179
Frames.putAll(args)
183180
callee(trail.last)
184181
} else {
185-
val restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => {
186-
Stack.reset(returnSize)
182+
val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => {
187183
Frames.popFrame()
188184
eval(rest, kont, trail)
189185
})
@@ -278,7 +274,7 @@ trait StagedWasmEvaluator extends SAIOps {
278274
}
279275
"no-op".reflectCtrlWith[Unit]()
280276
}
281-
val temp: Rep[Cont[Unit]] = topFun(haltK)
277+
val temp: Rep[Cont[Unit]] = fun(haltK)
282278
evalTop(temp, main)
283279
}
284280

@@ -796,6 +792,310 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
796792
} else {
797793
withStream(functionsStreams(id)._1)(f)
798794
}
795+
796+
override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = {
797+
val ng = init(g)
798+
emitln(prelude)
799+
emitln("""
800+
|/*****************************************
801+
|Emitting Generated Code
802+
|*******************************************/
803+
""".stripMargin)
804+
emitln("""
805+
#include <functional>
806+
#include <stdbool.h>
807+
#include <stdint.h>
808+
#include <string>
809+
#include <variant>""")
810+
val src = run(name, ng)
811+
emit(src)
812+
emitln("""
813+
|/*****************************************
814+
|End of Generated Code
815+
|*******************************************/
816+
|int main(int argc, char *argv[]) {
817+
| Snippet(std::monostate{});
818+
| return 0;
819+
|}""".stripMargin)
820+
}
821+
822+
val prelude = """
823+
#include <cassert>
824+
#include <cstdint>
825+
#include <cstdio>
826+
#include <iostream>
827+
#include <memory>
828+
#include <ostream>
829+
#include <variant>
830+
#include <vector>
831+
832+
#define info(x, ...)
833+
834+
class Num_t {
835+
public:
836+
virtual std::unique_ptr<Num_t> clone() const = 0;
837+
838+
virtual void display() = 0;
839+
virtual int32_t toInt() = 0;
840+
virtual int64_t toLong() = 0;
841+
};
842+
843+
class I32V_t : public Num_t {
844+
public:
845+
I32V_t(int32_t value) : value_(value) {}
846+
847+
std::unique_ptr<Num_t> clone() const override {
848+
return std::make_unique<I32V_t>(*this);
849+
}
850+
851+
void display() override { std::cout << value_ << std::endl; }
852+
853+
int32_t toInt() override { return value_; }
854+
855+
int64_t toLong() override { return static_cast<int64_t>(value_); }
856+
857+
private:
858+
int32_t value_;
859+
};
860+
861+
class I64V_t : public Num_t {
862+
public:
863+
I64V_t(int64_t value) : value_(value) {}
864+
865+
std::unique_ptr<Num_t> clone() const override {
866+
return std::make_unique<I64V_t>(*this);
867+
}
868+
869+
void display() override { std::cout << value_ << std::endl; }
870+
871+
int32_t toInt() override { return static_cast<int32_t>(value_); }
872+
873+
int64_t toLong() override { return value_; }
874+
875+
private:
876+
int64_t value_;
877+
};
878+
879+
struct Num {
880+
std::unique_ptr<Num_t> num_ptr;
881+
882+
// Constructions and destruction
883+
Num() : num_ptr(nullptr) {}
884+
885+
Num(std::unique_ptr<Num_t> num_ptr_) : num_ptr(std::move(num_ptr_)) {}
886+
887+
Num &operator=(const Num &other) {
888+
if (this != &other) {
889+
num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr;
890+
}
891+
return *this;
892+
}
893+
894+
Num(const Num &other) {
895+
num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr;
896+
}
897+
898+
Num(Num &&other) noexcept = default;
899+
900+
Num &operator=(Num &&other) noexcept = default;
901+
902+
~Num() = default;
903+
904+
int32_t toInt() const { return num_ptr->toInt(); }
905+
906+
int32_t toLong() const { return num_ptr->toLong(); }
907+
908+
void display() const { num_ptr->display(); }
909+
910+
Num operator+(const Num &other) const {
911+
if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
912+
dynamic_cast<I32V_t *>(other.num_ptr.get())) {
913+
return Num(
914+
std::make_unique<I32V_t>(I32V_t(this->toInt() + other.toInt())));
915+
} else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
916+
dynamic_cast<I64V_t *>(other.num_ptr.get())) {
917+
return Num(
918+
std::make_unique<I64V_t>(I64V_t(this->toLong() + other.toLong())));
919+
} else {
920+
throw std::runtime_error("Operands are of different types");
921+
}
922+
}
923+
924+
Num operator-(const Num &other) const {
925+
if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
926+
dynamic_cast<I32V_t *>(other.num_ptr.get())) {
927+
return Num(
928+
std::make_unique<I32V_t>(I32V_t(this->toInt() - other.toInt())));
929+
} else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
930+
dynamic_cast<I64V_t *>(other.num_ptr.get())) {
931+
return Num(
932+
std::make_unique<I64V_t>(I64V_t(this->toLong() - other.toLong())));
933+
} else {
934+
throw std::runtime_error("Operands are of different types");
935+
}
936+
}
937+
938+
bool operator==(const Num &other) const {
939+
if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
940+
dynamic_cast<I32V_t *>(other.num_ptr.get())) {
941+
return this->toInt() == other.toInt();
942+
} else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
943+
dynamic_cast<I64V_t *>(other.num_ptr.get())) {
944+
return this->toLong() == other.toLong();
945+
} else {
946+
throw std::runtime_error("Operands are of different types");
947+
}
948+
}
949+
950+
bool operator!=(const Num &other) const { return !(this->operator==(other)); }
951+
};
952+
953+
static Num I32V(int v) { return Num(std::make_unique<I32V_t>(v)); }
954+
955+
static Num I64V(int64_t v) { return Num(std::make_unique<I64V_t>(v)); }
956+
957+
// struct Slice {
958+
// int32_t start;
959+
// int32_t end;
960+
// Slice(int32_t start_, int32_t end_) : start(start_), end(end_) {}
961+
// };
962+
963+
using Slice = std::vector<Num>;
964+
965+
class Stack_t {
966+
public:
967+
void push(Num &&num) {
968+
assert(num.num_ptr != nullptr);
969+
stack_.push_back(std::move(num));
970+
}
971+
972+
void push(Num &num) {
973+
assert(num.num_ptr != nullptr);
974+
stack_.push_back(num);
975+
}
976+
977+
Num pop() {
978+
if (stack_.empty()) {
979+
throw std::runtime_error("Stack underflow");
980+
}
981+
Num num = std::move(stack_.back());
982+
assert(num.num_ptr != nullptr);
983+
stack_.pop_back();
984+
return num;
985+
}
986+
987+
Num peek() {
988+
if (stack_.empty()) {
989+
throw std::runtime_error("Stack underflow");
990+
}
991+
return stack_.back();
992+
}
993+
994+
Num get(int32_t index) {
995+
assert(index >= 0);
996+
assert(index < stack_.size());
997+
return stack_[index];
998+
}
999+
1000+
int32_t size() { return stack_.size(); }
1001+
1002+
void reset(int32_t size) {
1003+
if (size > stack_.size()) {
1004+
throw std::out_of_range("Invalid size");
1005+
}
1006+
while (stack_.size() > size) {
1007+
stack_.pop_back();
1008+
}
1009+
}
1010+
1011+
Slice take(int32_t size) {
1012+
if (size > stack_.size()) {
1013+
throw std::out_of_range("Invalid size");
1014+
}
1015+
// todo: avoid re-allocation
1016+
Slice slice(stack_.end() - size, stack_.end());
1017+
stack_.resize(stack_.size() - size);
1018+
return slice;
1019+
}
1020+
1021+
void print() {
1022+
std::cout << "Stack contents: " << std::endl;
1023+
for (const auto &num : stack_) {
1024+
num.display();
1025+
}
1026+
}
1027+
1028+
void initialize() { stack_.clear(); }
1029+
1030+
private:
1031+
std::vector<Num> stack_;
1032+
};
1033+
static Stack_t Stack;
1034+
1035+
struct Frame_t {
1036+
std::vector<Num> locals;
1037+
1038+
Frame_t(std::int32_t size) : locals() { locals.resize(size); }
1039+
Num &operator[](std::int32_t index) {
1040+
assert(index >= 0);
1041+
if (index >= locals.size()) {
1042+
throw std::out_of_range("Index out of range");
1043+
}
1044+
return locals[index];
1045+
}
1046+
void putAll(Slice slice) {
1047+
for (std::int32_t i = 0; i < slice.size(); ++i) {
1048+
locals[i] = slice[i];
1049+
}
1050+
}
1051+
};
1052+
1053+
class Frames_t {
1054+
public:
1055+
std::monostate popFrame() {
1056+
if (!frames.empty()) {
1057+
frames.pop_back();
1058+
return std::monostate{};
1059+
} else {
1060+
std::cout << "No frames to pop." << std::endl;
1061+
throw std::runtime_error("No frames to pop.");
1062+
}
1063+
}
1064+
1065+
Num get(std::int32_t index) {
1066+
auto ret = top()[index];
1067+
assert(ret.num_ptr != nullptr);
1068+
return ret;
1069+
}
1070+
1071+
void set(std::int32_t index, Num num) { frames.back()[index] = num; }
1072+
1073+
Frame_t &top() {
1074+
if (frames.empty()) {
1075+
throw std::runtime_error("No frames available");
1076+
}
1077+
return frames.back();
1078+
}
1079+
1080+
void pushFrame(std::int32_t size) {
1081+
Frame_t frame(size);
1082+
frames.push_back(frame);
1083+
}
1084+
1085+
void putAll(Slice slice) {
1086+
top().putAll(slice);
1087+
}
1088+
1089+
private:
1090+
std::vector<Frame_t> frames;
1091+
};
1092+
1093+
static Frames_t Frames;
1094+
1095+
static void initRand() {
1096+
// for now, just do nothing
1097+
}
1098+
"""
7991099
}
8001100

8011101

@@ -817,5 +1117,6 @@ object WasmToCppCompiler {
8171117
}
8181118
code.code
8191119
}
1120+
8201121
}
8211122

0 commit comments

Comments
 (0)