@@ -8,6 +8,7 @@ import lms.macros.SourceContext
8
8
import lms .core .stub .{Base , ScalaGenBase , CGenBase }
9
9
import lms .core .Backend ._
10
10
import lms .core .Backend .{Block => LMSBlock }
11
+ import lms .core .Graph
11
12
12
13
import gensym .wasm .ast ._
13
14
import gensym .wasm .ast .{Const => WasmConst , Block => WasmBlock }
@@ -88,20 +89,17 @@ trait StagedWasmEvaluator extends SAIOps {
88
89
// the type system guarantees that we will never take more than the input size from the stack
89
90
val funcTy = ty.funcType
90
91
// TODO: somehow the type of exitSize in residual program is nothing
91
- val exitSize = Stack .size - funcTy.inps.size + funcTy.out.size
92
- def restK : Rep [Cont [Unit ]] = topFun((_ : Rep [Unit ]) => {
93
- Stack .reset(exitSize)
92
+ def restK : Rep [Cont [Unit ]] = fun((_ : Rep [Unit ]) => {
94
93
eval(rest, kont, trail)
95
94
})
96
95
eval(inner, restK, restK :: trail)
97
96
case Loop (ty, inner) =>
98
97
val funcTy = ty.funcType
99
98
val exitSize = Stack .size - funcTy.inps.size + funcTy.out.size
100
- def restK = topFun((_ : Rep [Unit ]) => {
101
- Stack .reset(exitSize)
99
+ def restK = fun((_ : Rep [Unit ]) => {
102
100
eval(rest, kont, trail)
103
101
})
104
- def loop : Rep [Unit => Unit ] = topFun ((_u : Rep [Unit ]) => {
102
+ def loop : Rep [Unit => Unit ] = fun ((_u : Rep [Unit ]) => {
105
103
eval(inner, restK, loop :: trail)
106
104
})
107
105
loop(())
@@ -110,8 +108,7 @@ trait StagedWasmEvaluator extends SAIOps {
110
108
val exitSize = Stack .size - funcTy.inps.size + funcTy.out.size
111
109
val cond = Stack .pop()
112
110
// TODO: can we avoid code duplication here?
113
- def restK = topFun((_ : Rep [Unit ]) => {
114
- Stack .reset(exitSize)
111
+ def restK = fun((_ : Rep [Unit ]) => {
115
112
eval(rest, kont, trail)
116
113
})
117
114
if (cond != Values .I32 (0 )) {
@@ -182,8 +179,7 @@ trait StagedWasmEvaluator extends SAIOps {
182
179
Frames .putAll(args)
183
180
callee(trail.last)
184
181
} else {
185
- val restK : Rep [Cont [Unit ]] = topFun((_ : Rep [Unit ]) => {
186
- Stack .reset(returnSize)
182
+ val restK : Rep [Cont [Unit ]] = fun((_ : Rep [Unit ]) => {
187
183
Frames .popFrame()
188
184
eval(rest, kont, trail)
189
185
})
@@ -278,7 +274,7 @@ trait StagedWasmEvaluator extends SAIOps {
278
274
}
279
275
" no-op" .reflectCtrlWith[Unit ]()
280
276
}
281
- val temp : Rep [Cont [Unit ]] = topFun (haltK)
277
+ val temp : Rep [Cont [Unit ]] = fun (haltK)
282
278
evalTop(temp, main)
283
279
}
284
280
@@ -796,6 +792,310 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
796
792
} else {
797
793
withStream(functionsStreams(id)._1)(f)
798
794
}
795
+
796
+ override def emitAll (g : Graph , name : String )(m1 : Manifest [_], m2 : Manifest [_]): Unit = {
797
+ val ng = init(g)
798
+ emitln(prelude)
799
+ emitln("""
800
+ |/*****************************************
801
+ |Emitting Generated Code
802
+ |*******************************************/
803
+ """ .stripMargin)
804
+ emitln("""
805
+ #include <functional>
806
+ #include <stdbool.h>
807
+ #include <stdint.h>
808
+ #include <string>
809
+ #include <variant>""" )
810
+ val src = run(name, ng)
811
+ emit(src)
812
+ emitln("""
813
+ |/*****************************************
814
+ |End of Generated Code
815
+ |*******************************************/
816
+ |int main(int argc, char *argv[]) {
817
+ | Snippet(std::monostate{});
818
+ | return 0;
819
+ |}""" .stripMargin)
820
+ }
821
+
822
+ val prelude = """
823
+ #include <cassert>
824
+ #include <cstdint>
825
+ #include <cstdio>
826
+ #include <iostream>
827
+ #include <memory>
828
+ #include <ostream>
829
+ #include <variant>
830
+ #include <vector>
831
+
832
+ #define info(x, ...)
833
+
834
+ class Num_t {
835
+ public:
836
+ virtual std::unique_ptr<Num_t> clone() const = 0;
837
+
838
+ virtual void display() = 0;
839
+ virtual int32_t toInt() = 0;
840
+ virtual int64_t toLong() = 0;
841
+ };
842
+
843
+ class I32V_t : public Num_t {
844
+ public:
845
+ I32V_t(int32_t value) : value_(value) {}
846
+
847
+ std::unique_ptr<Num_t> clone() const override {
848
+ return std::make_unique<I32V_t>(*this);
849
+ }
850
+
851
+ void display() override { std::cout << value_ << std::endl; }
852
+
853
+ int32_t toInt() override { return value_; }
854
+
855
+ int64_t toLong() override { return static_cast<int64_t>(value_); }
856
+
857
+ private:
858
+ int32_t value_;
859
+ };
860
+
861
+ class I64V_t : public Num_t {
862
+ public:
863
+ I64V_t(int64_t value) : value_(value) {}
864
+
865
+ std::unique_ptr<Num_t> clone() const override {
866
+ return std::make_unique<I64V_t>(*this);
867
+ }
868
+
869
+ void display() override { std::cout << value_ << std::endl; }
870
+
871
+ int32_t toInt() override { return static_cast<int32_t>(value_); }
872
+
873
+ int64_t toLong() override { return value_; }
874
+
875
+ private:
876
+ int64_t value_;
877
+ };
878
+
879
+ struct Num {
880
+ std::unique_ptr<Num_t> num_ptr;
881
+
882
+ // Constructions and destruction
883
+ Num() : num_ptr(nullptr) {}
884
+
885
+ Num(std::unique_ptr<Num_t> num_ptr_) : num_ptr(std::move(num_ptr_)) {}
886
+
887
+ Num &operator=(const Num &other) {
888
+ if (this != &other) {
889
+ num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr;
890
+ }
891
+ return *this;
892
+ }
893
+
894
+ Num(const Num &other) {
895
+ num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr;
896
+ }
897
+
898
+ Num(Num &&other) noexcept = default;
899
+
900
+ Num &operator=(Num &&other) noexcept = default;
901
+
902
+ ~Num() = default;
903
+
904
+ int32_t toInt() const { return num_ptr->toInt(); }
905
+
906
+ int32_t toLong() const { return num_ptr->toLong(); }
907
+
908
+ void display() const { num_ptr->display(); }
909
+
910
+ Num operator+(const Num &other) const {
911
+ if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
912
+ dynamic_cast<I32V_t *>(other.num_ptr.get())) {
913
+ return Num(
914
+ std::make_unique<I32V_t>(I32V_t(this->toInt() + other.toInt())));
915
+ } else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
916
+ dynamic_cast<I64V_t *>(other.num_ptr.get())) {
917
+ return Num(
918
+ std::make_unique<I64V_t>(I64V_t(this->toLong() + other.toLong())));
919
+ } else {
920
+ throw std::runtime_error("Operands are of different types");
921
+ }
922
+ }
923
+
924
+ Num operator-(const Num &other) const {
925
+ if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
926
+ dynamic_cast<I32V_t *>(other.num_ptr.get())) {
927
+ return Num(
928
+ std::make_unique<I32V_t>(I32V_t(this->toInt() - other.toInt())));
929
+ } else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
930
+ dynamic_cast<I64V_t *>(other.num_ptr.get())) {
931
+ return Num(
932
+ std::make_unique<I64V_t>(I64V_t(this->toLong() - other.toLong())));
933
+ } else {
934
+ throw std::runtime_error("Operands are of different types");
935
+ }
936
+ }
937
+
938
+ bool operator==(const Num &other) const {
939
+ if (dynamic_cast<I32V_t *>(num_ptr.get()) &&
940
+ dynamic_cast<I32V_t *>(other.num_ptr.get())) {
941
+ return this->toInt() == other.toInt();
942
+ } else if (dynamic_cast<I64V_t *>(num_ptr.get()) &&
943
+ dynamic_cast<I64V_t *>(other.num_ptr.get())) {
944
+ return this->toLong() == other.toLong();
945
+ } else {
946
+ throw std::runtime_error("Operands are of different types");
947
+ }
948
+ }
949
+
950
+ bool operator!=(const Num &other) const { return !(this->operator==(other)); }
951
+ };
952
+
953
+ static Num I32V(int v) { return Num(std::make_unique<I32V_t>(v)); }
954
+
955
+ static Num I64V(int64_t v) { return Num(std::make_unique<I64V_t>(v)); }
956
+
957
+ // struct Slice {
958
+ // int32_t start;
959
+ // int32_t end;
960
+ // Slice(int32_t start_, int32_t end_) : start(start_), end(end_) {}
961
+ // };
962
+
963
+ using Slice = std::vector<Num>;
964
+
965
+ class Stack_t {
966
+ public:
967
+ void push(Num &&num) {
968
+ assert(num.num_ptr != nullptr);
969
+ stack_.push_back(std::move(num));
970
+ }
971
+
972
+ void push(Num &num) {
973
+ assert(num.num_ptr != nullptr);
974
+ stack_.push_back(num);
975
+ }
976
+
977
+ Num pop() {
978
+ if (stack_.empty()) {
979
+ throw std::runtime_error("Stack underflow");
980
+ }
981
+ Num num = std::move(stack_.back());
982
+ assert(num.num_ptr != nullptr);
983
+ stack_.pop_back();
984
+ return num;
985
+ }
986
+
987
+ Num peek() {
988
+ if (stack_.empty()) {
989
+ throw std::runtime_error("Stack underflow");
990
+ }
991
+ return stack_.back();
992
+ }
993
+
994
+ Num get(int32_t index) {
995
+ assert(index >= 0);
996
+ assert(index < stack_.size());
997
+ return stack_[index];
998
+ }
999
+
1000
+ int32_t size() { return stack_.size(); }
1001
+
1002
+ void reset(int32_t size) {
1003
+ if (size > stack_.size()) {
1004
+ throw std::out_of_range("Invalid size");
1005
+ }
1006
+ while (stack_.size() > size) {
1007
+ stack_.pop_back();
1008
+ }
1009
+ }
1010
+
1011
+ Slice take(int32_t size) {
1012
+ if (size > stack_.size()) {
1013
+ throw std::out_of_range("Invalid size");
1014
+ }
1015
+ // todo: avoid re-allocation
1016
+ Slice slice(stack_.end() - size, stack_.end());
1017
+ stack_.resize(stack_.size() - size);
1018
+ return slice;
1019
+ }
1020
+
1021
+ void print() {
1022
+ std::cout << "Stack contents: " << std::endl;
1023
+ for (const auto &num : stack_) {
1024
+ num.display();
1025
+ }
1026
+ }
1027
+
1028
+ void initialize() { stack_.clear(); }
1029
+
1030
+ private:
1031
+ std::vector<Num> stack_;
1032
+ };
1033
+ static Stack_t Stack;
1034
+
1035
+ struct Frame_t {
1036
+ std::vector<Num> locals;
1037
+
1038
+ Frame_t(std::int32_t size) : locals() { locals.resize(size); }
1039
+ Num &operator[](std::int32_t index) {
1040
+ assert(index >= 0);
1041
+ if (index >= locals.size()) {
1042
+ throw std::out_of_range("Index out of range");
1043
+ }
1044
+ return locals[index];
1045
+ }
1046
+ void putAll(Slice slice) {
1047
+ for (std::int32_t i = 0; i < slice.size(); ++i) {
1048
+ locals[i] = slice[i];
1049
+ }
1050
+ }
1051
+ };
1052
+
1053
+ class Frames_t {
1054
+ public:
1055
+ std::monostate popFrame() {
1056
+ if (!frames.empty()) {
1057
+ frames.pop_back();
1058
+ return std::monostate{};
1059
+ } else {
1060
+ std::cout << "No frames to pop." << std::endl;
1061
+ throw std::runtime_error("No frames to pop.");
1062
+ }
1063
+ }
1064
+
1065
+ Num get(std::int32_t index) {
1066
+ auto ret = top()[index];
1067
+ assert(ret.num_ptr != nullptr);
1068
+ return ret;
1069
+ }
1070
+
1071
+ void set(std::int32_t index, Num num) { frames.back()[index] = num; }
1072
+
1073
+ Frame_t &top() {
1074
+ if (frames.empty()) {
1075
+ throw std::runtime_error("No frames available");
1076
+ }
1077
+ return frames.back();
1078
+ }
1079
+
1080
+ void pushFrame(std::int32_t size) {
1081
+ Frame_t frame(size);
1082
+ frames.push_back(frame);
1083
+ }
1084
+
1085
+ void putAll(Slice slice) {
1086
+ top().putAll(slice);
1087
+ }
1088
+
1089
+ private:
1090
+ std::vector<Frame_t> frames;
1091
+ };
1092
+
1093
+ static Frames_t Frames;
1094
+
1095
+ static void initRand() {
1096
+ // for now, just do nothing
1097
+ }
1098
+ """
799
1099
}
800
1100
801
1101
@@ -817,5 +1117,6 @@ object WasmToCppCompiler {
817
1117
}
818
1118
code.code
819
1119
}
1120
+
820
1121
}
821
1122
0 commit comments