Skip to content

Commit 5805d7d

Browse files
committed
feat: add lift/lower flat_values
Signed-off-by: Gordon Smith <[email protected]>
1 parent 0432dd4 commit 5805d7d

File tree

10 files changed

+417
-215
lines changed

10 files changed

+417
-215
lines changed

Cargo.lock

Lines changed: 252 additions & 145 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ version = "0.1.0"
44
edition = "2024"
55

66
[dependencies]
7-
wasm-tools = "1.227.1"
7+
wasm-tools = "1.229.0"
88
wit-bindgen-cli = "0.41.0"

include/cmcpp/func.hpp

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
namespace cmcpp
77
{
88

9-
namespace funcXXX
9+
namespace func
1010
{
1111
// template <Flags T>
1212
// int32_t pack_flags_into_int(const T &v)
@@ -49,37 +49,63 @@ namespace cmcpp
4949
// std::memcpy(&buff, &i, ValTrait<T>::size);
5050
// return unpack_flags_from_int<T>(i);
5151
// }
52-
}
5352

54-
template <Func T>
55-
inline void flatten(LiftLowerContext &cx, const T &v, uint32_t ptr)
56-
{
57-
flags::store(cx, v, ptr);
58-
}
53+
enum class ContextType
54+
{
55+
Lift,
56+
Lower
57+
};
5958

60-
// template <Flags T>
61-
// inline void store(LiftLowerContext &cx, const T &v, uint32_t ptr)
62-
// {
63-
// flags::store(cx, v, ptr);
64-
// }
59+
template <Func T>
60+
inline core_func_t flatten(LiftLowerContext &cx, ContextType context)
61+
{
62+
std::vector<WasmValType> flat_params(ValTrait<T>::flat_params_types.begin(), ValTrait<T>::flat_params_types.end());
63+
std::vector<WasmValType> flat_results(ValTrait<T>::flat_result_types.begin(), ValTrait<T>::flat_result_types.end());
64+
// if (cx.opts.sync == true)
65+
{
66+
if (flat_params.size() > MAX_FLAT_PARAMS)
67+
{
68+
flat_params = {WasmValType::i32};
69+
}
70+
if (flat_results.size() > MAX_FLAT_RESULTS)
71+
{
72+
switch (context)
73+
{
74+
case ContextType::Lift:
75+
flat_results = {WasmValType::i32};
76+
break;
77+
case ContextType::Lower:
78+
flat_params.push_back(WasmValType::i32);
79+
flat_results = {};
80+
}
81+
}
82+
}
83+
return {flat_params, flat_results};
84+
}
6585

66-
// template <Flags T>
67-
// inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v)
68-
// {
69-
// return flags::lower_flat(cx, v);
70-
// }
86+
// template <Flags T>
87+
// inline void store(LiftLowerContext &cx, const T &v, uint32_t ptr)
88+
// {
89+
// flags::store(cx, v, ptr);
90+
// }
7191

72-
// template <Flags T>
73-
// inline T load(const LiftLowerContext &cx, uint32_t ptr)
74-
// {
75-
// return flags::load<T>(cx, ptr);
76-
// }
92+
// template <Flags T>
93+
// inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v)
94+
// {
95+
// return flags::lower_flat(cx, v);
96+
// }
7797

78-
// template <Flags T>
79-
// inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi)
80-
// {
81-
// return flags::lift_flat<T>(cx, vi);
82-
// }
83-
}
98+
// template <Flags T>
99+
// inline T load(const LiftLowerContext &cx, uint32_t ptr)
100+
// {
101+
// return flags::load<T>(cx, ptr);
102+
// }
84103

104+
// template <Flags T>
105+
// inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi)
106+
// {
107+
// return flags::lift_flat<T>(cx, vi);
108+
// }
109+
}
110+
}
85111
#endif

include/cmcpp/lift.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "context.hpp"
55
#include "util.hpp"
6+
#include "load.hpp"
67

78
namespace cmcpp
89
{
@@ -48,6 +49,27 @@ namespace cmcpp
4849
template <Option T>
4950
inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi);
5051

52+
template <Tuple T>
53+
inline T lift_heap_values(const LiftLowerContext &cx, const CoreValueIter &vi)
54+
{
55+
uint32_t ptr = vi.next<int32_t>();
56+
using tuple_type = typename std::tuple_element<0, typename ValTrait<T>::inner_type>::type;
57+
trap_if(cx, ptr != align_to(ptr, ValTrait<tuple_type>::alignment));
58+
trap_if(cx, ptr + ValTrait<tuple_type>::size > cx.opts.memory.size());
59+
auto retVal = load<tuple_type>(cx, ptr);
60+
return retVal;
61+
}
62+
63+
template <Tuple T>
64+
inline T lift_flat_values(const LiftLowerContext &cx, uint max_flat, const CoreValueIter &vi)
65+
{
66+
auto flat_types = ValTrait<T>::flat_types;
67+
if (flat_types.size() > max_flat)
68+
{
69+
return lift_heap_values<T>(cx, vi);
70+
}
71+
return lift_flat<T>(cx, vi);
72+
}
5173
}
5274

5375
#endif

include/cmcpp/list.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ namespace cmcpp
102102
template <List T>
103103
inline T load(const LiftLowerContext &cx, uint32_t ptr)
104104
{
105-
return list::load<T>(cx, ptr);
105+
return list::load<typename ValTrait<T>::inner_type>(cx, ptr);
106106
}
107107

108108
template <List T>

include/cmcpp/lower.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "list.hpp"
99
#include "flags.hpp"
1010
#include "tuple.hpp"
11+
#include "func.hpp"
1112
#include "util.hpp"
1213

1314
#include <tuple>
@@ -63,6 +64,36 @@ namespace cmcpp
6364

6465
template <Option T>
6566
inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v);
67+
68+
template <Tuple T>
69+
inline WasmValVector lower_heap_values(LiftLowerContext &cx, const T &vs)
70+
{
71+
using tuple_type = tuple_t<T>;
72+
tuple_type tuple_value = vs;
73+
auto ptr = cx.opts.realloc(0, 0, ValTrait<T>::alignment, ValTrait<T>::size);
74+
WasmValVector flat_vals = {ptr};
75+
trap_if(cx, ptr != align_to(ptr, ValTrait<tuple_type>::alignment));
76+
trap_if(cx, ptr + ValTrait<tuple_type>::size > cx.opts.memory.size());
77+
return flat_vals;
78+
}
79+
80+
template <Tuple T>
81+
inline WasmValVector lower_flat_values(LiftLowerContext &cx, uint max_flat, const T &vs)
82+
{
83+
// cx.inst.may_leave=false;
84+
WasmValVector retVal = {};
85+
auto flat_types = ValTrait<T>::flat_types;
86+
if (flat_types.size() > max_flat)
87+
{
88+
retVal = lower_heap_values(cx, vs);
89+
}
90+
else
91+
{
92+
retVal = lower_flat(cx, vs);
93+
}
94+
// cx.inst.may_leave=true;
95+
return retVal;
96+
}
6697
}
6798

6899
#endif

include/cmcpp/traits.hpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ namespace cmcpp
109109
using type = float64_t;
110110
};
111111

112+
class core_func_t
113+
{
114+
public:
115+
std::vector<WasmValType> params;
116+
std::vector<WasmValType> results;
117+
};
112118
// --------------------------------------------------------------------
113119

114120
enum class ValType : uint8_t
@@ -692,20 +698,32 @@ namespace cmcpp
692698
concept Option = ValTrait<T>::type == ValType::Option;
693699

694700
// Func --------------------------------------------------------------------
695-
constexpr int MAX_FLAT_PARAMS = 16;
696-
constexpr int MAX_FLAT_RESULTS = 1;
701+
constexpr uint MAX_FLAT_PARAMS = 16;
702+
constexpr uint MAX_FLAT_RESULTS = 1;
703+
704+
template <typename>
705+
struct func_t_impl;
697706

698707
template <Field R, Field... Args>
699-
using func_t = std::function<R(Args...)>;
708+
struct func_t_impl<R(Args...)>
709+
{
710+
using type = std::function<R(Args...)>;
711+
};
712+
713+
template <typename F>
714+
using func_t = typename func_t_impl<F>::type;
715+
700716
template <Field R, Field... Args>
701-
struct ValTrait<func_t<R, Args...>>
717+
struct ValTrait<std::function<R(Args...)>>
702718
{
703719
static constexpr ValType type = ValType::Func;
720+
using inner_type = std::function<R(Args...)>;
704721
using params_t = tuple_t<Args...>;
705-
using result_t = R;
722+
using results_t = tuple_t<R>;
706723
static constexpr auto flat_params_types = ValTrait<params_t>::flat_types;
707-
static constexpr auto flat_result_types = ValTrait<result_t>::flat_types;
724+
static constexpr auto flat_result_types = ValTrait<results_t>::flat_types;
708725
};
726+
709727
template <typename T>
710728
concept Func = ValTrait<T>::type == ValType::Func;
711729

samples/wamr/main.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,6 @@ WasmValVector fromWamr(size_t count, const wasm_val_t *values)
110110
assert(false);
111111
}
112112
}
113-
// for (size_t i = count; i < target_count; i++)
114-
// {
115-
// result[i] = (int32_t)(std::get<int32_t>(result[count - 1]) + (i * alignment));
116-
// }
117113
return result;
118114
}
119115

@@ -163,17 +159,18 @@ int main()
163159

164160
LiftLowerContext liftLowerContext(trap, convert, opts);
165161

166-
using and_func_t = std::function<bool_t(bool_t, bool_t)>;
162+
using and_func_t = func_t<bool_t(bool_t, bool_t)>;
167163
auto and_func = wasm_runtime_lookup_function(module_inst, "example:sample/booleans#and");
168164
and_func_t call_and = [&](bool_t a, bool_t b) -> bool_t
169165
{
170-
using inputs_t = ValTrait<and_func_t>::params_t;
171-
using outputs_t = ValTrait<and_func_t>::result_t;
172-
auto inputs = toWamr(lower_flat(liftLowerContext, inputs_t{a, b}));
166+
using params_t = ValTrait<and_func_t>::params_t;
167+
using results_t = ValTrait<and_func_t>::results_t;
168+
169+
auto inputs = toWamr(lower_flat_values(liftLowerContext, MAX_FLAT_PARAMS, params_t{a, b}));
173170
auto output_size = 1;
174171
wasm_val_t outputs[output_size];
175172
auto call_result = wasm_runtime_call_wasm_a(exec_env, and_func, output_size, outputs, inputs.size(), inputs.data());
176-
auto result = lift_flat<outputs_t>(liftLowerContext, fromWamr<outputs_t>(output_size, outputs));
173+
auto result = std::get<0>(lift_flat_values<results_t>(liftLowerContext, MAX_FLAT_RESULTS, fromWamr<results_t>(output_size, outputs)));
177174
std::cout << "and_func(" << a << ", " << b << "): " << result << std::endl;
178175
return result;
179176
};
@@ -182,77 +179,78 @@ int main()
182179
call_and(true, false);
183180
call_and(true, true);
184181

185-
using add_func_t = std::function<float64_t(float64_t, float64_t)>;
182+
using add_func_t = func_t<float64_t(float64_t, float64_t)>;
186183
auto add_func = wasm_runtime_lookup_function(module_inst, "example:sample/floats#add");
187184
add_func_t call_add = [&](float64_t input1, float64_t input2) -> float64_t
188185
{
189-
using inputs_t = ValTrait<add_func_t>::params_t;
190-
using outputs_t = ValTrait<add_func_t>::result_t;
186+
using params_t = ValTrait<add_func_t>::params_t;
187+
using results_t = ValTrait<add_func_t>::results_t;
191188

192-
auto inputs = toWamr(lower_flat(liftLowerContext, inputs_t{input1, input2}));
189+
auto inputs = toWamr(lower_flat_values(liftLowerContext, MAX_FLAT_PARAMS, params_t{input1, input2}));
193190
auto output_size = 1;
194191
wasm_val_t outputs[output_size];
195192
auto call_result = wasm_runtime_call_wasm_a(exec_env, add_func, output_size, outputs, inputs.size(), inputs.data());
196-
auto result = lift_flat<outputs_t>(liftLowerContext, fromWamr<outputs_t>(output_size, outputs));
193+
auto result = std::get<0>(lift_flat_values<results_t>(liftLowerContext, MAX_FLAT_RESULTS, fromWamr<results_t>(output_size, outputs)));
197194
std::cout << "add_func(" << input1 << ", " << input2 << "): " << result << std::endl;
198195
return result;
199196
};
200197
call_add(3.1, 0.2);
201198

202-
using reverse_func_t = std::function<string_t(string_t)>;
199+
using reverse_func_t = func_t<string_t(string_t)>;
203200
auto reverse_func = wasm_runtime_lookup_function(module_inst, "example:sample/strings#reverse");
204201
auto reverse_cleanup_func = wasm_runtime_lookup_function(module_inst, "cabi_post_example:sample/strings#reverse");
205202
reverse_func_t call_reverse = [&](string_t input1) -> string_t
206203
{
207-
using inputs_t = ValTrait<reverse_func_t>::params_t;
208-
using outputs_t = ValTrait<reverse_func_t>::result_t;
204+
auto flat_ft_lower = func::flatten<reverse_func_t>(liftLowerContext, func::ContextType::Lower);
205+
auto flat_ft_lift = func::flatten<reverse_func_t>(liftLowerContext, func::ContextType::Lift);
206+
207+
using params_t = ValTrait<reverse_func_t>::params_t;
208+
using results_t = ValTrait<reverse_func_t>::results_t;
209209

210-
auto inputs = toWamr(lower_flat(liftLowerContext, inputs_t{input1}));
210+
auto inputs = toWamr(lower_flat_values(liftLowerContext, MAX_FLAT_PARAMS, params_t{input1}));
211211
auto output_size = 1;
212212
wasm_val_t outputs[output_size];
213213
auto call_result = wasm_runtime_call_wasm_a(exec_env, reverse_func, output_size, outputs, inputs.size(), inputs.data());
214-
// auto result = load<outputs_t>(liftLowerContext, outputs[0].of.i32);
215-
auto result = lift_flat<outputs_t>(liftLowerContext, fromWamr<outputs_t>(output_size, outputs));
214+
auto result = std::get<0>(lift_flat_values<results_t>(liftLowerContext, MAX_FLAT_RESULTS, fromWamr<results_t>(output_size, outputs)));
216215
std::cout << "reverse_string(" << input1 << "): " << result << std::endl;
217-
call_result = wasm_runtime_call_wasm_a(exec_env, reverse_cleanup_func, 0, nullptr, 1, outputs);
218216
return result;
219217
};
220218
auto call_reverse_result = call_reverse("Hello World!");
221219
call_reverse(call_reverse_result);
222220

223-
using reverse_tuple_func_t = std::function<tuple_t<string_t, bool_t>(tuple_t<bool_t, string_t>)>;
221+
using reverse_tuple_func_t = func_t<tuple_t<string_t, bool_t>(tuple_t<bool_t, string_t>)>;
224222
auto reverse_tuple_func = wasm_runtime_lookup_function(module_inst, "example:sample/tuples#reverse");
225223
auto reverse_tuple_cleanup_func = wasm_runtime_lookup_function(module_inst, "cabi_post_example:sample/tuples#reverse");
226224
reverse_tuple_func_t call_reverse_tuple = [&](tuple_t<bool_t, string_t> a) -> tuple_t<string_t, bool_t>
227225
{
228-
using inputs_t = ValTrait<reverse_tuple_func_t>::params_t;
229-
using outputs_t = ValTrait<reverse_tuple_func_t>::result_t;
226+
using params_t = ValTrait<reverse_tuple_func_t>::params_t;
227+
using results_t = ValTrait<reverse_tuple_func_t>::results_t;
230228

231-
auto inputs = toWamr(lower_flat(liftLowerContext, a));
229+
auto inputs = toWamr(lower_flat_values(liftLowerContext, MAX_FLAT_PARAMS, params_t{a}));
232230
auto output_size = 1;
233231
wasm_val_t outputs[output_size];
234232
auto call_result = wasm_runtime_call_wasm_a(exec_env, reverse_tuple_func, output_size, outputs, inputs.size(), inputs.data());
235-
auto result = load<outputs_t>(liftLowerContext, outputs->of.i32);
233+
auto result = std::get<0>(lift_flat_values<results_t>(liftLowerContext, MAX_FLAT_RESULTS, fromWamr<results_t>(output_size, outputs)));
236234
std::cout << "reverse_tuple(" << std::get<0>(a) << ", " << std::get<1>(a) << "): " << std::get<0>(result) << ", " << std::get<1>(result) << std::endl;
237235
call_result = wasm_runtime_call_wasm_a(exec_env, reverse_tuple_cleanup_func, 0, nullptr, 1, outputs);
238236
return result;
239237
};
240238
auto call_reverse_tuple_result = call_reverse_tuple({false, "Hello World!"});
241239
// call_reverse_tuple({std::get<1>(call_reverse_tuple_result), std::get<0>(call_reverse_tuple_result}));
242240

243-
using list_filter_bool_func_t = std::function<list_t<string_t>(list_t<variant_t<bool_t, string_t>>)>;
241+
using list_filter_bool_func_t = func_t<list_t<string_t>(list_t<variant_t<bool_t, string_t>>)>;
244242
auto list_filter_bool_func = wasm_runtime_lookup_function(module_inst, "example:sample/lists#filter-bool");
245243
auto list_filter_bool_cleanup_func = wasm_runtime_lookup_function(module_inst, "cabi_post_example:sample/lists#filter-bool");
246244
auto call_list_filter_bool = [&](list_t<variant_t<bool_t, string_t>> a) -> list_t<string_t>
247245
{
248-
using inputs_t = ValTrait<list_filter_bool_func_t>::params_t;
249-
using outputs_t = ValTrait<list_filter_bool_func_t>::result_t;
246+
using params_t = ValTrait<list_filter_bool_func_t>::params_t;
247+
using results_t = ValTrait<list_filter_bool_func_t>::results_t;
250248

251-
auto inputs = toWamr(lower_flat(liftLowerContext, a));
249+
auto inputs = toWamr(lower_flat_values(liftLowerContext, MAX_FLAT_PARAMS, params_t{a}));
252250
auto output_size = 1;
253251
wasm_val_t outputs[output_size];
254252
auto call_result = wasm_runtime_call_wasm_a(exec_env, list_filter_bool_func, output_size, outputs, inputs.size(), inputs.data());
255-
auto result = lift_flat<outputs_t>(liftLowerContext, fromWamr<outputs_t>(output_size, outputs));
253+
auto result = std::get<0>(lift_flat_values<results_t>(liftLowerContext, MAX_FLAT_RESULTS, fromWamr<results_t>(output_size, outputs)));
256254
std::cout << "list_filter_bool(" << a.size() << "): " << result.size() << std::endl;
257255
call_result = wasm_runtime_call_wasm_a(exec_env, list_filter_bool_cleanup_func, 0, nullptr, 1, outputs);
258256
return result;

test/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,8 @@ TEST_CASE("Func")
604604
Heap heap(1024 * 1024);
605605
auto cx = createLiftLowerContext(&heap, Encoding::Utf8);
606606

607-
using MyFunc = func_t<uint32_t, string_t, list_t<string_t>>;
608-
MyFunc f = [](string_t b, list_t<string_t> c) -> uint32_t
607+
using MyFunc = func_t<uint32_t(string_t, string_t)>;
608+
MyFunc f = [](string_t b, string_t c) -> uint32_t
609609
{
610610
return 42;
611611
};

0 commit comments

Comments
 (0)