diff --git a/Cargo.lock b/Cargo.lock index b832ae3..0dea597 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,6 +220,26 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags 2.9.1", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.104", +] + [[package]] name = "bit-set" version = "0.8.0" @@ -277,18 +297,58 @@ version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +[[package]] +name = "bytemuck" +version = "1.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "cast" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "2.34.0" @@ -448,7 +508,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -459,7 +519,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -508,7 +568,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -557,7 +617,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -676,6 +736,12 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "half" version = "1.8.3" @@ -777,6 +843,16 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -829,6 +905,12 @@ dependencies = [ "paste", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mopro-msm" version = "0.2.0" @@ -841,6 +923,8 @@ dependencies = [ "ark-relations", "ark-serialize", "ark-std", + "bindgen", + "bytemuck", "criterion", "enumset", "instant", @@ -860,6 +944,16 @@ dependencies = [ "toml", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -988,6 +1082,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +dependencies = [ + "proc-macro2", + "syn 2.0.104", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -1135,6 +1239,12 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -1244,7 +1354,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1290,7 +1400,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1304,6 +1414,12 @@ dependencies = [ "digest", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "slab" version = "0.4.9" @@ -1338,9 +1454,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.101" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -1561,7 +1677,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -1583,7 +1699,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1746,7 +1862,7 @@ checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] [[package]] @@ -1766,5 +1882,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.104", ] diff --git a/mopro-msm/Cargo.toml b/mopro-msm/Cargo.toml index d340ccd..22bee35 100644 --- a/mopro-msm/Cargo.toml +++ b/mopro-msm/Cargo.toml @@ -40,11 +40,14 @@ rayon = "1.5.1" itertools = "0.13.0" rand = "0.8.5" +bytemuck = { version = "1.15", features = ["derive"] } + [build-dependencies] enumset = "1.0.8" toml = "0.8" serde = { version = "1.0", features = ["derive"] } serde_derive = "1.0" +bindgen = "0.71" [dev-dependencies] serial_test = "3.0.0" diff --git a/mopro-msm/build.rs b/mopro-msm/build.rs index 4e2e7ae..fd71636 100644 --- a/mopro-msm/build.rs +++ b/mopro-msm/build.rs @@ -15,6 +15,8 @@ fn compile_shaders() -> std::io::Result<()> { let shader_out_dir = out_dir.join("shaders"); fs::create_dir_all(&shader_out_dir)?; + build_cpp_header(manifest_dir.clone()); + // Check if we should compile all shaders for testing // Check environment variable first, then check if this is a test build let compile_all_shaders = env::var("MSM_COMPILE_ALL_SHADERS") @@ -172,6 +174,38 @@ fn compile_shaders() -> std::io::Result<()> { Ok(()) } +fn build_cpp_header(root_dir: PathBuf) { + println!("cargo:rerun-if-changed=src/msm/metal_msm/shader/cuzk/Common.h"); + + // macOS SDK root for clang + let sdk_root = String::from_utf8( + std::process::Command::new("xcrun") + .args(["--sdk", "macosx", "--show-sdk-path"]) + .output() + .unwrap() + .stdout, + ) + .unwrap() + .trim() + .to_owned(); + + let bindings = bindgen::Builder::default() + .header(format!( + "{}/src/msm/metal_msm/shader/cuzk/Common.h", + root_dir.to_str().unwrap() + )) + .clang_arg(format!("-isysroot{}", sdk_root)) + // .clang_arg("-x") // Objective-C dialect so #import works + // .clang_arg("objective-c") + .allowlist_type("Uniforms|Params") + .allowlist_type("BufferIndices|Attributes|TextureIndices") + .generate() + .expect("bindgen failed"); + + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings.write_to_file(out_dir.join("common.rs")).unwrap(); +} + fn detect_metal_version(sdk: &str) -> std::io::Result<(String, String, bool)> { // Use OS version to determine Metal version support let os_version = get_os_version(); diff --git a/mopro-msm/src/lib.rs b/mopro-msm/src/lib.rs index 3020b7b..60ac3bc 100644 --- a/mopro-msm/src/lib.rs +++ b/mopro-msm/src/lib.rs @@ -1,4 +1,5 @@ pub mod msm; +pub mod types; use thiserror::Error; #[derive(Debug, Error)] diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/Common.h b/mopro-msm/src/msm/metal_msm/shader/cuzk/Common.h new file mode 100644 index 0000000..9dc9413 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/Common.h @@ -0,0 +1,23 @@ +#ifndef Common_h +#define Common_h + +#include + +typedef struct { + matrix_float4x4 trueMatrix; + matrix_float4x4 falseMatrix; + matrix_float4x4 otherMatrix; +} Uniforms; + +typedef struct { + int width; + int height; +} Params; + +typedef enum { + VertexBuffer = 0, + ParamsBuffer = 2 +} BufferIndices; + + +#endif /* Common_h */ diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal index 6bc4d0e..b7065ce 100644 --- a/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal +++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal @@ -1,6 +1,7 @@ #include "../curve/jacobian.metal" #include "barrett_reduction.metal" #include +#include "Common.h" using namespace metal; #if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320) @@ -37,6 +38,8 @@ kernel void smvp( const uint subtask_idx = id / half_columns; + BufferIndices some = BufferIndices::ParamsBuffer; + Jacobian inf = get_bn254_zero_mont(); // an offset for each subtask's row_ptr diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/e2e.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/e2e.rs index 31837db..3751fa6 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/e2e.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/e2e.rs @@ -1,4 +1,5 @@ use crate::msm::metal_msm::metal_msm::metal_variable_base_msm; +use crate::types::raw::*; use ark_bn254::{Fr as ScalarField, G1Projective as G}; use ark_ec::CurveGroup; use ark_ec::VariableBaseMSM; @@ -23,6 +24,13 @@ mod tests { ); let start = std::time::Instant::now(); + let _ = BufferIndices_ParamsBuffer; + + let _params = Params { + width: 1920, + height: 1080, + }; + // Generate bases and scalars in parallel let (bases, scalars): (Vec<_>, Vec<_>) = (0..num_threads) .into_par_iter() diff --git a/mopro-msm/src/types.rs b/mopro-msm/src/types.rs new file mode 100644 index 0000000..ebf5709 --- /dev/null +++ b/mopro-msm/src/types.rs @@ -0,0 +1,14 @@ +#![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)] + +pub mod raw { + // 1. pull in the bindgen output that build.rs dropped in OUT_DIR + include!(concat!(env!("OUT_DIR"), "/common.rs")); + + // 2. add Pod/Zeroable so we can cast to &[u8] safely + use bytemuck::{Pod, Zeroable}; + + unsafe impl Zeroable for Uniforms {} + unsafe impl Pod for Uniforms {} + unsafe impl Zeroable for Params {} + unsafe impl Pod for Params {} +}