Skip to content

Commit

Permalink
c#: Handle Cabi realloc post return (bytecodealliance#1145)
Browse files Browse the repository at this point in the history
* wire in post return abi

* Deallocate allocs on exports

Signed-off-by: James Sturtevant <[email protected]>

* Fix cleanup and rebase

Signed-off-by: James Sturtevant <[email protected]>

* Add a large string array test

Signed-off-by: James Sturtevant <[email protected]>

* borrow the list

Signed-off-by: James Sturtevant <[email protected]>

* Fix language tests

Signed-off-by: James Sturtevant <[email protected]>

---------

Signed-off-by: James Sturtevant <[email protected]>
  • Loading branch information
jsturtevant authored Feb 13, 2025
1 parent 58897bf commit cd5e771
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 57 deletions.
180 changes: 126 additions & 54 deletions crates/csharp/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
// );
}

Instruction::ListCanonLower { element, realloc } => {
Instruction::ListCanonLower { element, .. } => {
let list: &String = &operands[0];
match self.interface_gen.direction {
Direction::Import => {
Expand Down Expand Up @@ -755,29 +755,20 @@ impl Bindgen for FunctionBindgen<'_, '_> {
results.push(format!("({list}).Length"));
}
Direction::Export => {
let (_, ty) = list_element_info(element);
let address = self.locals.tmp("address");
let buffer = self.locals.tmp("buffer");
let gc_handle = self.locals.tmp("gcHandle");
let size = self.interface_gen.csharp_gen.sizes.size(element).size_wasm32();
let byte_length = self.locals.tmp("byteLength");
uwrite!(
self.src,
"
byte[] {buffer} = new byte[({size}) * {list}.Length];
Buffer.BlockCopy({list}.ToArray(), 0, {buffer}, 0, ({size}) * {list}.Length);
var {gc_handle} = GCHandle.Alloc({buffer}, GCHandleType.Pinned);
var {address} = {gc_handle}.AddrOfPinnedObject();
var {byte_length} = ({size}) * {list}.Length;
var {address} = NativeMemory.Alloc((nuint)({byte_length}));
{list}.AsSpan().CopyTo(new Span<{ty}>({address},{byte_length}));
"
);

if realloc.is_none() {
self.needs_cleanup = true;
uwrite!(
self.src,
"
cleanups.Add(()=> {gc_handle}.Free());
");
}
results.push(format!("((IntPtr)({address})).ToInt32()"));
results.push(format!("(int)({address})"));
results.push(format!("{list}.Length"));
}
}
Expand All @@ -802,33 +793,45 @@ impl Bindgen for FunctionBindgen<'_, '_> {

Instruction::StringLower { realloc } => {
let op = &operands[0];
let interop_string = self.locals.tmp("interopString");
let str_ptr = self.locals.tmp("strPtr");
let utf8_bytes = self.locals.tmp("utf8Bytes");
let length = self.locals.tmp("length");
let gc_handle = self.locals.tmp("gcHandle");
uwriteln!(
self.src,
"
var {utf8_bytes} = Encoding.UTF8.GetBytes({op});
var {length} = {utf8_bytes}.Length;
var {gc_handle} = GCHandle.Alloc({utf8_bytes}, GCHandleType.Pinned);
var {interop_string} = {gc_handle}.AddrOfPinnedObject();
"
);

if realloc.is_none() {
results.push(format!("{interop_string}.ToInt32()"));
uwriteln!(
self.src,
"
var {utf8_bytes} = Encoding.UTF8.GetBytes({op});
var {length} = {utf8_bytes}.Length;
var {gc_handle} = GCHandle.Alloc({utf8_bytes}, GCHandleType.Pinned);
var {str_ptr} = {gc_handle}.AddrOfPinnedObject();
"
);

self.needs_cleanup = true;
uwrite!(
self.src,
"
cleanups.Add(()=> {gc_handle}.Free());
");
"
);
results.push(format!("{str_ptr}.ToInt32()"));
} else {
results.push(format!("{interop_string}.ToInt32()"));
let string_span = self.locals.tmp("stringSpan");
uwriteln!(
self.src,
"
var {string_span} = {op}.AsSpan();
var {length} = Encoding.UTF8.GetByteCount({string_span});
var {str_ptr} = NativeMemory.Alloc((nuint){length});
Encoding.UTF8.GetBytes({string_span}, new Span<byte>({str_ptr}, {length}));
"
);
results.push(format!("(int){str_ptr}"));
}
results.push(format!("{length}"));

results.push(format!("{length}"));
if FunctionKind::Freestanding == *self.kind || self.interface_gen.direction == Direction::Export {
self.interface_gen.require_interop_using("System.Text");
self.interface_gen.require_interop_using("System.Runtime.InteropServices");
Expand All @@ -851,7 +854,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
));
}

Instruction::ListLower { element, .. } => {
Instruction::ListLower { element, realloc } => {
let Block {
body,
results: block_results,
Expand All @@ -876,22 +879,38 @@ impl Bindgen for FunctionBindgen<'_, '_> {
);
let ret_area = self.locals.tmp("retArea");

self.needs_cleanup = true;
uwrite!(
self.src,
"
void* {address};
if (({size} * {list}.Count) < 1024) {{
var {ret_area} = stackalloc {element_type}[({array_size}*{list}.Count)+1];
{address} = (void*)(((int){ret_area}) + ({align} - 1) & -{align});
}}
else
{{
var {buffer_size} = {size} * (nuint){list}.Count;
{address} = NativeMemory.AlignedAlloc({buffer_size}, {align});
cleanups.Add(()=> NativeMemory.AlignedFree({address}));
}}
match realloc {
None => {
self.needs_cleanup = true;
uwrite!(self.src,
"
void* {address};
if (({size} * {list}.Count) < 1024) {{
var {ret_area} = stackalloc {element_type}[({array_size}*{list}.Count)+1];
{address} = (void*)(((int){ret_area}) + ({align} - 1) & -{align});
}}
else
{{
var {buffer_size} = {size} * (nuint){list}.Count;
{address} = NativeMemory.AlignedAlloc({buffer_size}, {align});
cleanups.Add(() => NativeMemory.AlignedFree({address}));
}}
"
);
}
Some(_) => {
//cabi_realloc_post_return will be called to clean up this allocation
uwrite!(self.src,
"
var {buffer_size} = {size} * (nuint){list}.Count;
void* {address} = NativeMemory.AlignedAlloc({buffer_size}, {align});
"
);
}
}

uwrite!(self.src,
"
for (int {index} = 0; {index} < {list}.Count; ++{index}) {{
{ty} {block_element} = {list}[{index}];
int {base} = (int){address} + ({index} * {size});
Expand Down Expand Up @@ -1035,7 +1054,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
}
}

Instruction::Return { amt: _, func } => {
Instruction::Return { amt, .. } => {
if self.fixed_statments.len() > 0 {
let fixed: String = self.fixed_statments.iter().map(|f| format!("{} = {}", f.ptr_name, f.item_to_pin)).collect::<Vec<_>>().join(", ");
self.src.insert_str(0, &format!("fixed (void* {fixed})
Expand All @@ -1055,7 +1074,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
}

if !matches!((self.interface_gen.direction, self.kind), (Direction::Import, FunctionKind::Constructor(_))) {
match func.results.len() {
match *amt {
0 => (),
1 => {
self.handle_result_import(operands);
Expand All @@ -1075,19 +1094,72 @@ impl Bindgen for FunctionBindgen<'_, '_> {
Instruction::Malloc { .. } => unimplemented!(),

Instruction::GuestDeallocate { .. } => {
uwriteln!(self.src, r#"Console.WriteLine("TODO: deallocate buffer for indirect parameters");"#);
// the original alloc here comes from cabi_realloc implementation (wasi-libc in .net)
uwriteln!(self.src, r#"NativeMemory.Free((void*){});"#, operands[0]);
}

Instruction::GuestDeallocateString => {
uwriteln!(self.src, r#"Console.WriteLine("TODO: deallocate buffer for string");"#);
uwriteln!(self.src, r#"NativeMemory.Free((void*){});"#, operands[0]);
}

Instruction::GuestDeallocateVariant { .. } => {
uwriteln!(self.src, r#"Console.WriteLine("TODO: deallocate buffer for variant");"#);
Instruction::GuestDeallocateVariant { blocks } => {
let cases = self
.blocks
.drain(self.blocks.len() - blocks..)
.enumerate()
.map(|(i, Block { body, results, .. })| {
assert!(results.is_empty());

format!(
"case {i}: {{
{body}
break;
}}"
)
})
.collect::<Vec<_>>()
.join("\n");

let op = &operands[0];

uwrite!(
self.src,
"
switch ({op}) {{
{cases}
}}
"
);
}

Instruction::GuestDeallocateList { .. } => {
uwriteln!(self.src, r#"Console.WriteLine("TODO: deallocate buffer for list");"#);
Instruction::GuestDeallocateList { element: element_type } => {
let Block {
body,
results: block_results,
base,
element: _,
} = self.blocks.pop().unwrap();
assert!(block_results.is_empty());

let address = &operands[0];
let length = &operands[1];
let size = self.interface_gen.csharp_gen.sizes.size(element_type).size_wasm32();

if !body.trim().is_empty() {
let index = self.locals.tmp("index");

uwrite!(
self.src,
"
for (int {index} = 0; {index} < {length}; ++{index}) {{
int {base} = (int){address} + ({index} * {size});
{body}
}}
"
);
}

uwriteln!(self.src, r#"NativeMemory.Free((void*){});"#, operands[0]);
}

Instruction::HandleLower {
Expand Down
30 changes: 27 additions & 3 deletions crates/csharp/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,37 @@ impl InterfaceGenerator<'_> {
"#
);

if !sig.results.is_empty() {
if abi::guest_export_needs_post_return(self.resolve, func) {
let params = sig
.results
.iter()
.enumerate()
.map(|(i, param)| {
let ty = crate::world_generator::wasm_type(*param);
format!("{ty} p{i}")
})
.collect::<Vec<_>>()
.join(", ");

let mut bindgen = FunctionBindgen::new(
self,
"INVALID",
&func.kind,
(0..sig.results.len()).map(|i| format!("p{i}")).collect(),
Vec::new(),
ParameterType::ABI,
);

abi::post_return(bindgen.interface_gen.resolve, func, &mut bindgen, false);

let src = bindgen.src;

uwrite!(
self.csharp_interop_src,
r#"
[UnmanagedCallersOnly(EntryPoint = "cabi_post_{export_name}")]
{access} static void cabi_post_{interop_name}({wasm_result_type} returnValue) {{
Console.WriteLine("TODO: cabi_post_{export_name}");
{access} static unsafe void cabi_post_{interop_name}({params}) {{
{src}
}}
"#
);
Expand Down
6 changes: 6 additions & 0 deletions tests/runtime/lists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ impl test::lists::test::Host for MyImports {
assert_eq!(ptr, [(1, 2, 3), (4, 5, 6)]);
}

fn list_param_large(&mut self, ptr: Vec<String>) {
assert_eq!(ptr.len(), 1000);
}

fn list_result(&mut self) -> Vec<u8> {
vec![1, 2, 3, 4, 5]
}
Expand Down Expand Up @@ -133,6 +137,8 @@ fn run_test(lists: Lists, store: &mut Store<crate::Wasi<MyImports>>) -> Result<(
vec!["baz".to_owned()],
],
)?;
let arg0: Vec<String> = (0..1000).map(|_| "string".to_string()).collect();
exports.call_list_param_large(&mut *store, &arg0)?;
assert_eq!(exports.call_list_result(&mut *store)?, [1, 2, 3, 4, 5]);
assert_eq!(exports.call_list_result2(&mut *store)?, "hello!");
assert_eq!(
Expand Down
5 changes: 5 additions & 0 deletions tests/runtime/lists/wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ void exports_test_lists_test_list_param4(lists_list_list_string_t *a) {
lists_list_list_string_free(a);
}

void exports_test_lists_test_list_param_large(lists_list_string_t *a) {
assert(a->len == 1000);
lists_list_string_free(a);
}

void exports_test_lists_test_list_param5(lists_list_tuple3_u8_u32_u8_t *a) {
assert(a->len == 2);
assert(a->ptr[0].f0 == 1);
Expand Down
12 changes: 12 additions & 0 deletions tests/runtime/lists/wasm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ public static void TestImports()
}
});

List<string> randomStrings = new List<string>();
for (int i = 0; i < 1000; i++)
{
randomStrings.Add(Guid.NewGuid().ToString());
}
TestInterop.ListParamLarge(randomStrings);

{
byte[] result = TestInterop.ListResult();
Debug.Assert(result.Length == 5);
Expand Down Expand Up @@ -233,6 +240,11 @@ public static void ListParam5(List<(byte, uint, byte)> a)
Debug.Assert(a[1].Item3 == 6);
}

public static void ListParamLarge(List<String> a)
{
Debug.Assert(a.Count() == 1000);
}

public static byte[] ListResult()
{
return new byte[] { (byte)1, (byte)2, (byte)3, (byte)4, (byte)5 };
Expand Down
14 changes: 14 additions & 0 deletions tests/runtime/lists/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"math"
"math/rand"
"strconv"
. "wit_lists_go/gen"
)

Expand Down Expand Up @@ -29,6 +31,12 @@ func (i ListImpl) TestImports() {
TestListsTestListParam2("foo")
TestListsTestListParam3([]string{"foo", "bar", "baz"})
TestListsTestListParam4([][]string{{"foo", "bar"}, {"baz"}})

randomStrings := make([]string, 1000)
for i := 0; i < 1000; i++ {
randomStrings[i] = "str" + strconv.Itoa(rand.Intn(1000))
}
TestListsTestListParamLarge(randomStrings)
res3 := TestListsTestListResult()
if len(res3) != 5 {
panic("TestListsTestListResult")
Expand Down Expand Up @@ -212,6 +220,12 @@ func (i ListImpl) ListParam5(a []ExportsTestListsTestTuple3U8U32U8T) {
}
}

func (i ListImpl) ListParamLarge(a []string) {
if len(a) != 1000 {
panic("ListParamLarge")
}
}

func (i ListImpl) ListResult() []uint8 {
return []uint8{1, 2, 3, 4, 5}
}
Expand Down
Loading

0 comments on commit cd5e771

Please sign in to comment.