diff --git a/examples/rust/simd/Cargo.toml b/examples/rust/simd/Cargo.toml new file mode 100644 index 0000000..9a5a577 --- /dev/null +++ b/examples/rust/simd/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "simd" +version = "0.1.0" +edition = "2018" + +[dependencies] +wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen.git" } + +[lib] +crate-type = ["cdylib"] diff --git a/examples/rust/simd/README.md b/examples/rust/simd/README.md new file mode 100644 index 0000000..79d1474 --- /dev/null +++ b/examples/rust/simd/README.md @@ -0,0 +1,23 @@ +# Simd Rust + +A simple simd example using [core::arch::wasm32](https://doc.rust-lang.org/core/arch/wasm32/index.html#simd) + +## Build + +```sh +cargo build --target wasm32-wasi +``` + +## Create Function + +```sql +CREATE FUNCTION `u64x2-dot` AS WASM FROM INFILE 'target/wasm32-unknown-unknown/debug/simd.wasm' WITH WIT FROM INFILE 'simd.wit' +CREATE FUNCTION `u64x2-inner` RETURNS TABLE AS WASM FROM INFILE 'target/wasm32-unknown-unknown/debug/simd.wasm' WITH WIT FROM INFILE 'simd.wit' +``` + +## Example Queries + +```sql +SELECT * FROM `u64x2-dot`([1,2,3], [0,5,6]); +SELECT * FROM `u64x2-inner`([1,2,3], [3,4,5]); +``` diff --git a/examples/rust/simd/simd.wit b/examples/rust/simd/simd.wit new file mode 100644 index 0000000..8c53687 --- /dev/null +++ b/examples/rust/simd/simd.wit @@ -0,0 +1,4 @@ +u64x2-scalar-mul: func(a: u64, b: list) -> list +u64x2-dot: func(a: list, b: list) -> u64 +u64x2-inner: func(a: list, b: list) -> list +u64x2-mat-mul: func(a: list>, b: list>) -> list> diff --git a/examples/rust/simd/src/lib.rs b/examples/rust/simd/src/lib.rs new file mode 100644 index 0000000..108b1ce --- /dev/null +++ b/examples/rust/simd/src/lib.rs @@ -0,0 +1,89 @@ +#[cfg(target_arch = "wasm32")] +wit_bindgen_rust::export!("simd.wit"); + +struct Simd; + +use core::arch::wasm32::*; + +impl simd::Simd for Simd { + fn u64x2_scalar_mul(a: u64, b: Vec) -> Vec { + let va: v128 = u64x2_splat(a); + let n = b.len(); + let mut res: Vec = vec![0; b.len()]; + let mut i = 0; + while i + 1 < n { + let vb: v128 = u64x2(b[i], b[i + 1]); + let s: v128 = u64x2_mul(va, vb); + res[i] = u64x2_extract_lane::<0>(s); + res[i + 1] = u64x2_extract_lane::<0>(s); + i += 2; + } + for j in 1..(n % 2 + 1) { + res[n - j] = a * b[n - j]; + } + res + } + + fn u64x2_dot(a: Vec, b: Vec) -> u64 { + assert!(a.len() == b.len()); + let n = a.len(); + let mut sum: v128 = u64x2(0, 0); + let mut i = 0; + while i + 1 < n { + let va: v128 = u64x2(a[i], a[i + 1]); + let vb: v128 = u64x2(b[i], b[i + 1]); + sum = u64x2_add(sum, u64x2_mul(va, vb)); + i += 2; + } + for j in 1..(n % 2 + 1) { + return u64x2_extract_lane::<0>(sum) + + u64x2_extract_lane::<1>(sum) + + a[n - j] * b[n - j]; + } + u64x2_extract_lane::<0>(sum) + u64x2_extract_lane::<1>(sum) + } + + fn u64x2_inner(a: Vec, b: Vec) -> Vec { + assert!(a.len() == b.len()); + let n = a.len(); + let mut res = vec![0; n]; + let mut i = 0; + while i + 1 < n { + let va: v128 = u64x2(a[i], a[i + 1]); + let vb: v128 = u64x2(b[i], b[i + 1]); + let m: v128 = u64x2_mul(va, vb); + res[i] = u64x2_extract_lane::<0>(m); + res[i + 1] = u64x2_extract_lane::<1>(m); + i += 2; + } + for j in 1..(n % 2 + 1) { + res[n - 1] = a[n - j] * b[n - j]; + } + res + } + + fn u64x2_mat_mul(a: Vec>, b: Vec>) -> Vec> { + assert!(a.len() > 0 && b.len() > 0); + assert!(a[0].len() == b.len()); + + let mut res = vec![vec![0; a.len()]; b[0].len()]; + let n = a.len(); + let m = b.len(); + for i in 0..n { + for j in 0..b[0].len() { + let mut k = 0; + while k + 1 < m { + let va: v128 = u64x2(a[i][k], a[i][k + 1]); + let vb: v128 = u64x2(b[k][j], b[k + 1][j]); + let m: v128 = u64x2_mul(va, vb); + res[i][j] += u64x2_extract_lane::<0>(m) + u64x2_extract_lane::<1>(m); + k += 2; + } + for t in 1..(m % 2 + 1) { + res[i][j] += a[i][m - t] * b[m - t][j]; + } + } + } + res + } +}