diff --git a/.gitignore b/.gitignore
index b96f1de3..3c0349dc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,4 @@
.vscode
.DS_Store
target/
-vectors/
-graph.bin
**/*.metallib
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index 3c218daa..b832ae35 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -831,7 +831,7 @@ dependencies = [
[[package]]
name = "mopro-msm"
-version = "0.1.0"
+version = "0.2.0"
dependencies = [
"ark-bn254",
"ark-crypto-primitives",
diff --git a/Makefile b/Makefile
index af4d9b3f..017f6446 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,10 @@
-# Existing Makefile content...
-
# Clean target to remove .metal.lib and .metal.ir files
clean_ir:
find ./mopro-msm/src/msm/metal_msm/shader -type f \( -name "*.metal.lib" -o -name "*.metal.ir" \) -delete
.PHONY: clean
+
+# Format MSL shaders using clang-format
+format_shaders:
+ find mopro-msm/src/msm/metal_msm/shader -name "*.metal" -exec xcrun clang-format -i --style=WebKit {} \;
+.PHONY: format_shaders
diff --git a/README.md b/README.md
index cf5ec7c6..21a7a6a0 100644
--- a/README.md
+++ b/README.md
@@ -1,39 +1,193 @@
-# mopro msm gpu-acceleration
+# Metal MSM
-We are researching and implementing methods to accelerate multi-scalar multiplication (MSM) on IOS mobile device.
+Metal-MSM v2 executes MSM on [BN254](https://hackmd.io/@jpw/bn254) curve on Apple GPUs using Metal Shading Language (MSL). Unlike v1, which naively split the work into smaller tasks, v2 takes [Tal and Koh’s WebGPU MSM](https://github.com/z-prize/2023-entries/tree/main/prize-2-msm-wasm/webgpu-only/tal-derei-koh-wei-jie) in ZPrize2023 and the cuZK [[LWY+23](https://eprint.iacr.org/2022/1321)] approach as reference.
-## mopro-msm
+By adopting sparse matrices, it improves the Pippenger algorithm [Pip76](https://dl.acm.org/doi/10.1109/SFCS.1976.21) with a more memory-efficient storage format and uses well-studied sparse matrix algorithms, such as sparse matrix–vector multiplication and sparse matrix transposition, in both the preprocessing phase (e.g., radix sort via sparse matrix transpose) and the bucket-accumulation phase to achieve high parallelism.
-This is a of various implementations of MSM functions, which are then integrated in `mopro-core`.
+We took the WebGPU MSM reference and tuned it for all scales by auto-adjusting workgroup sizes for each cuZK shaders with SIMD width and the amount of GPU cores, squeezing out better GPU utilization. Plus, with dynamic window sizes, we speed up small and medium inputs (2^14 – 2^18) by eliminating unused sparse-matrix columns.
-### Run benchmark on the laptop
-Currently we support these MSM algorithms on BN254:
-- arkworks_pippenger
-- bucket_wise_msm
-- precompute_msm
-- metal::msm (GPU)
+One thing to highlight is that our implementation runs most computations on the GPU, but it’s still slower than the CPU-only solution like [Arkworks](https://github.com/arkworks-rs). However, because we target client-side devices with limited resources, applying a hybrid approach, leveraging both CPU and GPU for MSM tasks and combining the results at the end, can yield an implementation slightly faster than a pure-CPU one. Check the write-up below for estimated speedups with this hybrid method.
-Replace `MSM_ALGO` with one of the algorithms above to get the corresponding benchmarks.
+## How to use
-Benchmarking for single instance size:
-```sh
-cargo test --release --package mopro-msm --lib -- msm::MSM_ALGO::tests::test_run_benchmark --exact --nocapture
+Metal MSM v2 works with `arkworks v0.4.x`; just include the crate in your `Cargo.toml`.
+```toml
+mopro-msm = { git = "https://github.com/zkmopro/gpu-acceleration.git", tag = "v0.2.0" }
```
-Benchmarking for multiple instance size:
-```sh
-cargo test --release --package mopro-msm --lib -- msm::MSM_ALGO::tests::test_run_multi_benchmarks --exact --nocapture
+Next, invoke MSM within your Rust code.
+```rust
+use mopro_msm::msm::metal_msm::{
+ metal_variable_base_msm,
+ test_utils::generate_random_bases_and_scalars, // optional
+};
+
+fn main() {
+ let input_size = 1 << 16;
+ let (bases, scalars) = generate_random_bases_and_scalars(input_size);
+ let msm_result = metal_variable_base_msm(&bases, &scalars);
+
+ println!("Result: {:?}", msm_result);
+}
```
-## gpu-exploration-app
+Because it’s compatible with Arkworks, you can seamlessly swap between Metal MSM and the Arkworks MSM implementation.
+```rust
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use ark_bn254::{Fr as ScalarField, G1Projective as G};
+ use ark_ec::{CurveGroup, VariableBaseMSM};
+ use ark_std::{UniformRand, test_rng};
+
+ #[test]
+ fn test_msm() {
+ let input_size = 1 << 10;
-This is a benchmark app to compare the performance of different algorithms on iOS device.
+ // Generate random EC points and scalars with Arkworks
+ let mut rng = test_rng();
+ let bases = (0..input_size)
+ .map(|_| G::rand(&mut rng).into_affine())
+ .collect::>();
+ let scalars = (0..input_size)
+ .map(|_| ScalarField::rand(&mut rng))
+ .collect::>();
-You can run the following commands in the root directory of the project to compile the metal library for a given OS:
-```sh
-# for macOS
-bash mopro-msm/src/msm/metal/compile_metal.sh
+ let metal_msm_result = metal_variable_base_msm(&bases, &scalars).unwrap();
+ let arkworks_msm_result = G::msm(&bases, &scalars).unwrap();
-# for iphoneOS
-bash mopro-msm/src/msm/metal/compile_metal_iphone.sh
+ assert_eq!(metal_msm_result, arkworks_msm_result); // the result is the same
+ }
+}
```
+
+## Benchmark
+
+Benchmarking on BN254 curve ran on a MacBook Air with M3 chips, with test case setup time excluded.
+
+
+
+
+ | Scheme |
+ Input Size (ms) |
+
+
+ | 212 |
+ 214 |
+ 216 |
+ 218 |
+ 220 |
+ 222 |
+ 224 |
+
+
+
+
+ Arkworks v0.4.x (CPU, Baseline) |
+ 6 |
+ 19 |
+ 69 |
+ 245 |
+ 942 |
+ 3,319 |
+ 14,061 |
+
+
+ Metal MSM v0.1.0 (GPU) |
+ 143 (-23.8x) |
+ 273 (-14.4x) |
+ 1,730 (-25.1x) |
+ 10,277 (-41.9x) |
+ 41,019 (-43.5x) |
+ 555,877 (-167.5x) |
+ N/A |
+
+
+ Metal MSM v0.2.0 (GPU) |
+ 134 (-22.3x) |
+ 124 (-6.5x) |
+ 253 (-3.7x) |
+ 678 (-2.8x) |
+ 1,702 (-1.8x) |
+ 5,390 (-1.6x) |
+ 22,241 (-1.6x) |
+
+
+ ICME WebGPU MSM (GPU) |
+ N/A |
+ N/A |
+ 2,719 (-39.4x) |
+ 5,418 (-22.1x) |
+ 17,475 (-18.6x) |
+ N/A |
+ N/A |
+
+
+ ICICLE-Metal v3.8.0 (GPU) |
+ 59 (-9.8x) |
+ 54 (-2.8x) |
+ 89 (-1.3x) |
+ 149 (+1.6x) |
+ 421 (+2.2x) |
+ 1,288 (+2.6x) |
+ 4,945 (+2.8x) |
+
+
+
+
+> side note:
+> - for ICME WebGPU MSM, input size 2^12 causes M3 chip machines to crash; any sizes not listed on the project’s GitHub page are shown as "N/A"
+> - for Metal MSM v0.1.0, the 2^24 benchmark was abandoned because it exceeded practical runtime
+
+## Profiling summary (v1 vs v2)
+
+Environment: M1 Pro, macOS 15.2, curve `ark_bn254`, dataset 2^20 unless stated. Medians of 5 runs.
+
+### v2 → v1
+
+| metric | v1[^1] | v2[^2] | gain |
+|---|---|---|---|
+| end-to-end latency | 10.3 s | **0.42 s** | **×24** |
+| GPU occupancy | 32 % | 76 % | +44 pp |
+| CPU share | 19 % | **<3 %** | –16 pp |
+| peak VRAM | 1.6 GB | **220 MB** | –7.3× |
+
+Key changes:
+
+* single sparse-matrix kernel eliminates most launches and memory thrash
+* CSR buckets keep data on-device → near-zero host↔GPU traffic
+* on-GPU radix sort makes preprocessing parallel
+
+## Future
+
+### Technical Improvements
+- **Modern Dependencies**: Update to `objc2` and `objc2-metal` ([objc2](https://github.com/madsmtm/objc2))
+- **Metal 4**: Adopt latest [Metal 4](https://developer.apple.com/metal/whats-new/) features
+- **Refactor with SIMD in mind**:
+ - Instruction-level parallelism using vector types for faster FMA within SIMD groups
+ - Memory coalescing to increase locality (e.g., structure of array instead of array of structure)
+ - Optimized input reading patterns (e.g. `[X_i || Y_i]_0^{n-1}` instead of separate arrays)
+ - Latency hiding and occupancy fine-tuning
+ - Minimize thread divergence
+
+### Algorithm & Integration
+- **CPU-GPU Hybrid**: Research interleaving with CPU MSM crate and update to `arkworks 0.5`
+- **Advanced Algorithms**:
+ - Elastic MSM [[ZHY+24](https://eprint.iacr.org/2024/057.pdf)] implementation
+ - Faster modular reduction with LogJump ([article by Wei Jie](https://kohweijie.com/articles/25/logjumps.html), [Barret-Montgomery](https://hackmd.io/@Ingonyama/Barret-Montgomery))
+
+### Platform Expansion
+- **Cross-platform**: WGSL support with native execution environment
+- **Crypto Math Library**: Maintain a Metal/WebGPU crypto math library
+
+## Community
+
+- X account:
+- Telegram group:
+
+## Acknowledgements
+
+This work was initially sponsored by a joint grant from [PSE](https://pse.dev/) and [0xPARC](https://0xparc.org/). It is currently incubated by PSE.
+
+[^1]: https://hackmd.io/@yaroslav-ya/rJkpqc_Nke
+[^2]: https://hackmd.io/@yaroslav-ya/HyFA7XAQll
\ No newline at end of file
diff --git a/example-app/Cargo.lock b/example-app/Cargo.lock
index 74f318ac..4fe8f306 100644
--- a/example-app/Cargo.lock
+++ b/example-app/Cargo.lock
@@ -1348,7 +1348,7 @@ dependencies = [
[[package]]
name = "mopro-msm"
-version = "0.1.0"
+version = "0.2.0"
dependencies = [
"ark-bn254 0.4.0",
"ark-crypto-primitives 0.4.0",
diff --git a/example-app/Config.toml b/example-app/Config.toml
index 8b3aca31..0c4e64a8 100644
--- a/example-app/Config.toml
+++ b/example-app/Config.toml
@@ -1,8 +1,8 @@
target_adapters = ["circom"]
target_platforms = ["ios"]
ios = [
+ "x86_64-apple-ios",
"aarch64-apple-ios-sim",
"aarch64-apple-ios",
- "x86_64-apple-ios",
]
android = []
diff --git a/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a b/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a
index b17958a7..2b9093c3 100644
Binary files a/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a and b/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a differ
diff --git a/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a b/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a
index ee032428..c141475c 100644
Binary files a/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a and b/example-app/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a differ
diff --git a/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a b/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a
index b17958a7..2b9093c3 100644
Binary files a/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a and b/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64/libexample_app.a differ
diff --git a/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a b/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a
index ee032428..c141475c 100644
Binary files a/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a and b/example-app/ios/MoproiOSBindings/MoproBindings.xcframework/ios-arm64_x86_64-simulator/libexample_app.a differ
diff --git a/mopro-msm/.gitignore b/mopro-msm/.gitignore
index 4d30bfb7..30229109 100644
--- a/mopro-msm/.gitignore
+++ b/mopro-msm/.gitignore
@@ -13,12 +13,6 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
-# GPU exploration - preprocessed vectors
-src/middleware/gpu_explorations/utils/vectors
-
-# GPU exploration - proptest generated files
-proptest-regressions
-
# Metal shader intermediate files and libraries
src/msm/metal_msm/shader/**/*.ir
src/msm/metal_msm/shader/**/*.lib
diff --git a/mopro-msm/Cargo.toml b/mopro-msm/Cargo.toml
index 14a5b738..c591253f 100644
--- a/mopro-msm/Cargo.toml
+++ b/mopro-msm/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "mopro-msm"
-version = "0.1.0"
+version = "0.2.0"
edition = "2021"
build = "build.rs"
diff --git a/mopro-msm/README.md b/mopro-msm/README.md
deleted file mode 100644
index bd4338bc..00000000
--- a/mopro-msm/README.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# mopro-msm-benchmarks
-
-Core crate for outputting multi-scalar multiplication integration and implementations
-
-## Catagories
-
-### GPU-based (metal) msm
-
-* window-wise msm
-* bucket-wise msm
-
-### CPU-based msm
-
-* precomputation + window-wise msm
-* bucket-wise msm
-* arkworks pippenger
-* trapdoorTech (integrated from trapdoorTech)
-
-## Results on MacOs
-
-
-## Results on IOS
diff --git a/mopro-msm/benchmark_results/bucket_wise_msm_benchmark.txt b/mopro-msm/benchmark_results/bucket_wise_msm_benchmark.txt
deleted file mode 100644
index e1ed5e2a..00000000
--- a/mopro-msm/benchmark_results/bucket_wise_msm_benchmark.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-msm_size,num_msm,avg_processing_time(ms)
-16,5,563.822834
-18,5,1615.4173584
-20,5,9367.2433832
diff --git a/mopro-msm/benchmark_results/metal_msm_benchmark.txt b/mopro-msm/benchmark_results/metal_msm_benchmark.txt
deleted file mode 100644
index d2635a57..00000000
--- a/mopro-msm/benchmark_results/metal_msm_benchmark.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-msm_size,num_msm,avg_processing_time(ms)
-8,10,76.07561229999999
-12,10,856.7078458
-16,10,3936.7850999
-18,10,26981.7295124
diff --git a/mopro-msm/build.rs b/mopro-msm/build.rs
index 7c23a2fe..14c65d12 100644
--- a/mopro-msm/build.rs
+++ b/mopro-msm/build.rs
@@ -68,12 +68,10 @@ fn compile_shaders() -> std::io::Result<()> {
}
println!(
- "cargo:warning=Found {} Metal shaders to compile",
- metal_paths.len()
+ "cargo:warning=Found {} Metal shaders to compile from {}",
+ metal_paths.len(),
+ shader_root.to_str().unwrap()
);
- for path in &metal_paths {
- println!("cargo:warning=Including shader: {}", path.display());
- }
// Combine selected kernels into one .metal file
let combined = shader_out_dir.join("msm_combined.metal");
@@ -86,24 +84,46 @@ fn compile_shaders() -> std::io::Result<()> {
}
fs::write(&combined, &combined_src)?;
- // Compile combined source to AIR
+ // Determine SDK target
let target = env::var("TARGET").unwrap_or_default();
let sdk = if target.contains("apple-ios") {
"iphoneos"
} else {
"macosx"
};
+
+ // Only detect Metal version if we're not in CI
+ let (metal_std, enable_logging) = if env::var("CI").is_err() {
+ let (metal_version, metal_std, enable_logging) = detect_metal_version(sdk)?;
+ println!("cargo:warning=Detected Metal version: {}", metal_version);
+ (metal_std, enable_logging)
+ } else {
+ println!("cargo:warning=Running in CI - using safe Metal 3.0 standard without logging");
+ ("metal3.0".to_string(), false)
+ };
+
+ // Compile combined source to AIR
let air = shader_out_dir.join("msm.air");
+ let metal_std_arg = format!("-std={}", metal_std);
+ let mut metal_args = vec![
+ "-sdk",
+ sdk,
+ "metal",
+ &metal_std_arg,
+ "-c",
+ combined.to_str().unwrap(),
+ "-o",
+ air.to_str().unwrap(),
+ ];
+
+ // Add logging flag only if enabled (which only happens when not in CI and Metal >= 3.2)
+ if enable_logging {
+ metal_args.insert(metal_args.len() - 3, "-fmetal-enable-logging");
+ println!("cargo:warning=Enabling Metal logging");
+ }
+
let status = Command::new("xcrun")
- .args(&[
- "-sdk",
- sdk,
- "metal",
- "-c",
- combined.to_str().unwrap(),
- "-o",
- air.to_str().unwrap(),
- ])
+ .args(&metal_args)
.status()
.expect("Failed to invoke metal");
if !status.success() {
@@ -137,6 +157,92 @@ fn compile_shaders() -> std::io::Result<()> {
Ok(())
}
+fn detect_metal_version(sdk: &str) -> std::io::Result<(String, String, bool)> {
+ // Use OS version to determine Metal version support
+ let os_version = get_os_version();
+ let (major, minor) = determine_metal_version_from_os(&os_version);
+
+ // Determine Metal standard and whether to enable logging
+ let (metal_std, enable_logging) = if major > 3 || (major == 3 && minor >= 2) {
+ ("metal3.2".to_string(), true)
+ } else if major == 3 && minor >= 1 {
+ ("metal3.1".to_string(), false)
+ } else if major == 3 {
+ ("metal3.0".to_string(), false)
+ } else if major == 2 && minor >= 4 {
+ // For Metal 2.x, we need platform-specific prefixes
+ let platform_prefix = if sdk == "iphoneos" { "ios" } else { "macos" };
+ (format!("{}-metal2.4", platform_prefix), false)
+ } else {
+ // For Metal 2.x, we need platform-specific prefixes
+ let platform_prefix = if sdk == "iphoneos" { "ios" } else { "macos" };
+ (format!("{}-metal2.3", platform_prefix), false)
+ };
+
+ let version_str = format!("{}.{}", major, minor);
+ println!(
+ "cargo:warning=Detected OS version: {} -> Metal version: {}",
+ os_version, version_str
+ );
+ Ok((version_str, metal_std, enable_logging))
+}
+
+fn get_os_version() -> String {
+ // Get OS version using system APIs
+ #[cfg(target_os = "macos")]
+ {
+ use std::process::Command;
+ if let Ok(output) = Command::new("sw_vers").arg("-productVersion").output() {
+ if output.status.success() {
+ return String::from_utf8_lossy(&output.stdout).trim().to_string();
+ }
+ }
+ }
+
+ #[cfg(target_os = "ios")]
+ {
+ // For iOS, we can assume modern Metal support
+ return "16.0".to_string();
+ }
+
+ // Fallback
+ "10.15".to_string()
+}
+
+fn determine_metal_version_from_os(os_version: &str) -> (u32, u32) {
+ let parts: Vec<&str> = os_version.split('.').collect();
+ if parts.len() >= 2 {
+ if let (Ok(major), Ok(minor)) = (parts[0].parse::(), parts[1].parse::()) {
+ // macOS to Metal version mapping
+ #[cfg(target_os = "macos")]
+ {
+ if major >= 14 || (major == 13) {
+ return (3, 2); // macOS 13.0+ supports Metal 3.2
+ } else if major >= 12 || (major == 11) {
+ return (3, 0); // macOS 11.0+ supports Metal 3.0
+ } else if major >= 11 || (major == 10 && minor >= 15) {
+ return (2, 4); // macOS 10.15+ supports Metal 2.4
+ }
+ }
+
+ // iOS to Metal version mapping
+ #[cfg(target_os = "ios")]
+ {
+ if major >= 16 {
+ return (3, 2); // iOS 16+ supports Metal 3.2
+ } else if major >= 14 {
+ return (3, 0); // iOS 14+ supports Metal 3.0
+ } else if major >= 13 {
+ return (2, 4); // iOS 13+ supports Metal 2.4
+ }
+ }
+ }
+ }
+
+ // Default fallback
+ (2, 3)
+}
+
fn visit_dirs(dir: &Path, paths: &mut Vec) -> std::io::Result<()> {
if dir.is_dir() {
for entry in fs::read_dir(dir)? {
diff --git a/mopro-msm/src/msm/bucket_wise_msm.rs b/mopro-msm/src/msm/bucket_wise_msm.rs
deleted file mode 100644
index 00282398..00000000
--- a/mopro-msm/src/msm/bucket_wise_msm.rs
+++ /dev/null
@@ -1,365 +0,0 @@
-use ark_bn254::{Fr as ScalarField, G1Projective as G};
-use ark_ec::VariableBaseMSM;
-use ark_ff::{BigInteger, PrimeField};
-use ark_std::{self, cfg_into_iter, One};
-use std::sync::Mutex;
-use std::time::{Duration, Instant};
-
-use crate::{
- msm::utils::{benchmark::BenchmarkResult, preprocess},
- MoproError,
-};
-use rayon::prelude::*;
-
-// Helper function for getting the windows size
-fn ln_without_floats(a: usize) -> usize {
- // log2(a) * ln(2)
- (ark_std::log2(a) * 69 / 100) as usize
-}
-
-/// use bucket-wise accumulation in msm
-fn bucket_wise_msm(
- bases: &[V::MulBase],
- scalars: &[V::ScalarField],
-) -> Result {
- let bigints = cfg_into_iter!(scalars)
- .map(|s| s.into_bigint())
- .collect::>();
- let instance_size = ark_std::cmp::min(bases.len(), bigints.len());
- let scalars = &bigints[..instance_size];
- let bases = &bases[..instance_size];
-
- let c = if instance_size < 32 {
- 3
- } else {
- ln_without_floats(instance_size) + 2
- };
-
- let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
- let one = V::ScalarField::one().into_bigint();
-
- let zero = V::zero();
- let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();
- let num_window = window_starts.len();
- let bucket_len = (1 << c) - 1;
-
- // prepare buckets and points indices (no need to sort in this case)
- let prepare_start = Instant::now();
- let indices_lists = Mutex::new(vec![(0 as usize, 0 as usize); instance_size * num_window]);
- ark_std::cfg_into_iter!(scalars)
- .enumerate()
- .for_each(|(point_idx, each_scalar)| {
- if *each_scalar == one {
- return;
- }
-
- for i in 0..num_window {
- let w_start = window_starts[i];
- let mut scalar = *each_scalar;
- scalar.divn(w_start as u32);
- let scalar = scalar.as_ref()[0] % (1 << c);
- if scalar != 0 {
- let bucket_idx = i * bucket_len + (scalar as usize) - 1;
- let mut indices_lists = indices_lists.lock().unwrap();
- indices_lists[point_idx * num_window + i] = (bucket_idx, point_idx);
- }
- }
- });
- println!(
- "Prepare buckets indices time: {:?}",
- prepare_start.elapsed()
- );
-
- // sort the buckets_indices parallelly
- let sort_start = Instant::now();
- let mut indices_lists = indices_lists.lock().unwrap();
- // indices_lists.par_sort_unstable_by_key(|a| a.0); // unstable version is faster
- indices_lists.par_sort_unstable_by(|a, b| {
- if a.0 == b.0 {
- a.1.cmp(&b.1)
- } else {
- a.0.cmp(&b.0)
- }
- });
-
- // remove the first few (0, 0) indices
- let mut k = 0;
- while (indices_lists[k].0 == 0) && (indices_lists[k].1 == 0) {
- k += 1;
- }
- indices_lists.par_drain(0..k);
- println!("Sort buckets indices time: {:?}", sort_start.elapsed());
-
- // find the start and end of each bucket
- let total_buckets_size = num_window * bucket_len;
- let mut bucket_start = vec![0; total_buckets_size];
- let mut bucket_end = vec![0; total_buckets_size];
- let mut prev_bucket_idx = 0;
- let last_idx = indices_lists.len() - 1;
- for (idx, (bucket_idx, _)) in indices_lists.iter().enumerate() {
- if idx == 0 {
- prev_bucket_idx = *bucket_idx;
- } else {
- if *bucket_idx != prev_bucket_idx {
- bucket_end[prev_bucket_idx] = idx;
- bucket_start[*bucket_idx] = idx;
- prev_bucket_idx = *bucket_idx;
- }
- // add the last idx to the end
- if idx == last_idx {
- bucket_end[*bucket_idx] = idx + 1;
- }
- }
- }
-
- // build an active bucket list to reduce meaning initialization of threads
- let active_buckets: Vec<_> = ark_std::cfg_into_iter!(0..total_buckets_size)
- .filter(|i| bucket_start[*i] != 0 || bucket_end[*i] != 0)
- .collect();
-
- // do the bucket-wise accumulation
- let accumulation_start = Instant::now();
- let buckets = Mutex::new(vec![V::zero(); total_buckets_size]);
- ark_std::cfg_into_iter!(active_buckets).for_each(|bucket_idx| {
- let mut buckets = buckets.lock().unwrap();
- for i in bucket_start[bucket_idx]..bucket_end[bucket_idx] {
- buckets[bucket_idx] += bases[indices_lists[i].1];
- }
- });
- println!(
- "Accumulate buckets time: {:?}",
- accumulation_start.elapsed()
- );
-
- // do window-wise reduction
- let reduction_start = Instant::now();
- let window_sums: Vec<_> = {
- let buckets_vec: Vec<_> =
- ark_std::cfg_into_iter!(buckets.lock().unwrap().clone()).collect();
- buckets_vec
- .par_chunks(bucket_len)
- .enumerate()
- .map(|(window_idx, bucket)| {
- let mut res = zero;
- if window_idx == 0 {
- for i in 0..instance_size {
- if scalars[i] == one {
- res += &bucket[i];
- }
- }
- }
- let mut running_sum = zero;
- bucket.into_iter().rev().for_each(|b| {
- running_sum += b;
- res += &running_sum;
- });
- res
- })
- .collect()
- };
- println!("Sum reduction time: {:?}", reduction_start.elapsed());
-
- // We store the sum for the lowest window.
- let lowest = *window_sums.first().unwrap();
-
- // We're traversing windows from high to low.
- Ok(lowest
- + &window_sums[1..]
- .iter()
- .rev()
- .fold(zero, |mut total, sum_i| {
- total += sum_i;
- for _ in 0..c {
- total.double_in_place();
- }
- total
- }))
-}
-
-pub fn benchmark_msm(
- instances: I,
- iterations: u32,
-) -> Result, preprocess::HarnessError>
-where
- I: Iterator- ,
-{
- let mut instance_durations = Vec::new();
-
- for instance in instances {
- let points = &instance.0;
- // map each scalar to a ScalarField
- let scalars = &instance
- .1
- .iter()
- .map(|s| ScalarField::new(*s))
- .collect::>();
- let mut instance_total_duration = Duration::ZERO;
- for _i in 0..iterations {
- let start = Instant::now();
- let _ = bucket_wise_msm::(&points[..], &scalars[..]).unwrap();
- instance_total_duration += start.elapsed();
- }
- let instance_avg_duration = instance_total_duration / iterations;
-
- println!(
- "Average time to execute MSM with {} points and {} scalars in {} iterations is: {:?}",
- points.len(),
- scalars.len(),
- iterations,
- instance_avg_duration,
- );
- instance_durations.push(instance_avg_duration);
- }
- Ok(instance_durations)
-}
-
-pub fn run_benchmark(
- instance_size: u32,
- num_instance: u32,
- utils_dir: &str,
-) -> Result {
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&utils_dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(instance_size, num_instance, &utils_dir);
- }
- }
-
- let benchmark_data = preprocess::FileInputIterator::open(&utils_dir).unwrap();
- let instance_durations = benchmark_msm(benchmark_data, 1).unwrap();
- // in milliseconds
- let avg_processing_time: f64 = instance_durations
- .iter()
- .map(|d| d.as_secs_f64() * 1000.0)
- .sum::()
- / instance_durations.len() as f64;
-
- println!("Done running benchmark.");
- Ok(BenchmarkResult {
- instance_size: instance_size,
- num_instance: num_instance,
- avg_processing_time: avg_processing_time,
- })
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use ark_serialize::Write;
- use std::fs::File;
-
- const INSTANCE_SIZE: u32 = 16;
- const NUM_INSTANCE: u32 = 1;
- const UTILSPATH: &str = "src/msm/utils/vectors";
- const BENCHMARKSPATH: &str = "benchmark_results";
-
- #[test]
- fn test_msm_correctness_medium_sample() {
- let dir = format!("{}/{}/{}x{}", preprocess::get_root_path(), UTILSPATH, 8, 5);
-
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(INSTANCE_SIZE, NUM_INSTANCE, &dir);
- }
- }
-
- let instances = preprocess::FileInputIterator::open(&dir).unwrap();
-
- for (i, instance) in instances.enumerate() {
- let points = &instance.0;
- // map each scalar to a ScalarField
- let scalars = &instance
- .1
- .iter()
- .map(|s| ScalarField::new(*s))
- .collect::>();
- let arkworks_msm = G::msm(&points[..], &scalars[..]).unwrap();
- let msm = bucket_wise_msm::(&points[..], &scalars[..]).unwrap();
- assert_eq!(msm, arkworks_msm, "This msm is wrongly computed");
- println!(
- "(pass) {}th instance of size 2^{} is correctly computed",
- i, 8
- );
- }
- }
-
- #[test]
- fn test_benchmark_msm() {
- let dir = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- UTILSPATH,
- INSTANCE_SIZE,
- NUM_INSTANCE
- );
-
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(INSTANCE_SIZE, NUM_INSTANCE, &dir);
- }
- }
-
- let benchmark_data = preprocess::FileInputIterator::open(&dir).unwrap();
- let result = benchmark_msm(benchmark_data, 1);
- println!("Done running benchmark: {:?}", result);
- }
-
- #[test]
- fn test_run_benchmark() {
- let utils_path = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- &UTILSPATH,
- INSTANCE_SIZE,
- NUM_INSTANCE
- );
- let result = run_benchmark(INSTANCE_SIZE, NUM_INSTANCE, &utils_path).unwrap();
- println!("Benchmark result: {:#?}", result);
- }
-
- #[test]
- fn test_run_multi_benchmarks() {
- let output_path = format!(
- "{}/{}/{}_benchmark.txt",
- preprocess::get_root_path(),
- &BENCHMARKSPATH,
- "bucket_wise_msm"
- );
- let mut output_file = File::create(output_path).expect("output file creation failed");
- writeln!(output_file, "msm_size,num_msm,avg_processing_time(ms)").unwrap();
-
- let instance_size = vec![16, 18, 20, 22, 24, 26];
- let num_instance = vec![5];
- for size in &instance_size {
- for num in &num_instance {
- let utils_path = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- &UTILSPATH,
- *size,
- *num
- );
- let result = run_benchmark(*size, *num, &utils_path).unwrap();
- println!("{}x{} result: {:#?}", *size, *num, result);
- writeln!(
- output_file,
- "{},{},{}",
- result.instance_size, result.num_instance, result.avg_processing_time
- )
- .unwrap();
- }
- }
- }
-}
diff --git a/mopro-msm/src/msm/metal/abstraction/errors.rs b/mopro-msm/src/msm/metal/abstraction/errors.rs
deleted file mode 100644
index 9a1e935c..00000000
--- a/mopro-msm/src/msm/metal/abstraction/errors.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-use thiserror::Error;
-
-#[derive(Debug, Error)]
-pub enum MetalError {
- #[error("Couldn't find a system default device for Metal")]
- DeviceNotFound(),
- #[error("Couldn't create a new Metal library: {0}")]
- LibraryError(String),
- #[error("Couldn't create a new Metal function object: {0}")]
- FunctionError(String),
- #[error("Couldn't create a new Metal compute pipeline: {0}")]
- PipelineError(String),
- #[error("Could not calculate {1} root of unity")]
- RootOfUnityError(String, u64),
- // #[error("Input length is {0}, which is not a power of two")]
- // InputError(usize),
- #[error("Invalid input: {0}")]
- InputError(String),
-}
diff --git a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs b/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs
deleted file mode 100644
index ebf772b5..00000000
--- a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs
+++ /dev/null
@@ -1,162 +0,0 @@
-use ark_bn254::Fq;
-use ark_ff::biginteger::{BigInteger, BigInteger256};
-
-use crate::msm::metal::abstraction::mont_reduction;
-
-// implement to_u32_limbs and from_u32_limbs for BigInt<4>
-pub trait ToLimbs {
- fn to_u32_limbs(&self) -> Vec;
- fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec;
-}
-
-pub trait FromLimbs {
- fn from_u32_limbs(limbs: &[u32]) -> Self;
- fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
- fn from_u128(num: u128) -> Self;
- fn from_u32(num: u32) -> Self;
-}
-
-// convert from little endian to big endian
-impl ToLimbs for BigInteger256 {
- fn to_u32_limbs(&self) -> Vec {
- let mut limbs = Vec::new();
- self.to_bytes_be().chunks(8).for_each(|chunk| {
- let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
- let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
- limbs.push(high);
- limbs.push(low);
- });
- limbs
- }
-
- fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec {
- let mut result = vec![0u32; num_limbs];
- let limb_size = 1u32 << log_limb_size;
- let mask = limb_size - 1;
-
- // Convert to little-endian representation
- let bytes = self.to_bytes_le();
- let mut val = 0u32;
- let mut bits = 0u32;
- let mut limb_idx = 0;
-
- for &byte in bytes.iter() {
- val |= (byte as u32) << bits;
- bits += 8;
-
- while bits >= log_limb_size && limb_idx < num_limbs {
- result[limb_idx] = val & mask;
- val >>= log_limb_size;
- bits -= log_limb_size;
- limb_idx += 1;
- }
- }
-
- // Handle any remaining bits
- if bits > 0 && limb_idx < num_limbs {
- result[limb_idx] = val;
- }
-
- result
- }
-}
-
-// convert from little endian to big endian
-impl ToLimbs for Fq {
- fn to_u32_limbs(&self) -> Vec {
- let mut limbs = Vec::new();
- self.0.to_bytes_be().chunks(8).for_each(|chunk| {
- let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
- let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
- limbs.push(high);
- limbs.push(low);
- });
- limbs
- }
-
- fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec {
- self.0.to_limbs(num_limbs, log_limb_size)
- }
-}
-
-impl FromLimbs for BigInteger256 {
- // convert from big endian to little endian for metal
- fn from_u32_limbs(limbs: &[u32]) -> Self {
- let mut big_int = [0u64; 4];
- for (i, limb) in limbs.chunks(2).rev().enumerate() {
- let high = u64::from(limb[0]);
- let low = u64::from(limb[1]);
- big_int[i] = (high << 32) | low;
- }
- BigInteger256::new(big_int)
- }
- // provide little endian u128 since arkworks use this value as well
- fn from_u128(num: u128) -> Self {
- let high = (num >> 64) as u64;
- let low = num as u64;
- BigInteger256::new([low, high, 0, 0])
- }
- // provide little endian u32 since arkworks use this value as well
- fn from_u32(num: u32) -> Self {
- BigInteger256::new([num as u64, 0, 0, 0])
- }
-
- fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
- let mut result = [0u64; 4];
- let limb_size = log_limb_size as usize;
- let mut accumulated_bits = 0;
- let mut current_u64 = 0u64;
- let mut result_idx = 0;
-
- for &limb in limbs {
- // Add the current limb at the appropriate position
- current_u64 |= (limb as u64) << accumulated_bits;
- accumulated_bits += limb_size;
-
- // If we've accumulated 64 bits or more, store the result
- while accumulated_bits >= 64 && result_idx < 4 {
- result[result_idx] = current_u64;
- current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64));
- accumulated_bits -= 64;
- result_idx += 1;
- }
- }
-
- // Handle any remaining bits
- if accumulated_bits > 0 && result_idx < 4 {
- result[result_idx] = current_u64;
- }
-
- BigInteger256::new(result)
- }
-}
-
-impl FromLimbs for Fq {
- // convert from big endian to little endian for metal
- fn from_u32_limbs(limbs: &[u32]) -> Self {
- let mut big_int = [0u64; 4];
- for (i, limb) in limbs.chunks(2).rev().enumerate() {
- let high = u64::from(limb[0]);
- let low = u64::from(limb[1]);
- big_int[i] = (high << 32) | low;
- }
- Fq::new(mont_reduction::raw_reduction(BigInteger256::new(big_int)))
- }
- fn from_u128(num: u128) -> Self {
- let high = (num >> 64) as u64;
- let low = num as u64;
- Fq::new(mont_reduction::raw_reduction(BigInteger256::new([
- low, high, 0, 0,
- ])))
- }
- fn from_u32(num: u32) -> Self {
- Fq::new(mont_reduction::raw_reduction(BigInteger256::new([
- num as u64, 0, 0, 0,
- ])))
- }
-
- fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
- let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
- Fq::new(mont_reduction::raw_reduction(bigint))
- }
-}
diff --git a/mopro-msm/src/msm/metal/abstraction/mod.rs b/mopro-msm/src/msm/metal/abstraction/mod.rs
deleted file mode 100644
index e9d20b6a..00000000
--- a/mopro-msm/src/msm/metal/abstraction/mod.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-pub mod errors;
-pub mod limbs_conversion;
-pub mod mont_reduction;
-pub mod state;
diff --git a/mopro-msm/src/msm/metal/abstraction/mont_reduction.rs b/mopro-msm/src/msm/metal/abstraction/mont_reduction.rs
deleted file mode 100644
index db0069ff..00000000
--- a/mopro-msm/src/msm/metal/abstraction/mont_reduction.rs
+++ /dev/null
@@ -1,40 +0,0 @@
-use ark_bn254::FqConfig;
-use ark_ff::{
- biginteger::{arithmetic as fa, BigInt},
- fields::models::{MontBackend, MontConfig},
- Fp,
-};
-
-// Reference: https://github.com/arkworks-rs/algebra/blob/master/ff/src/fields/models/fp/montgomery_backend.rs#L373-L389
-const N: usize = 4;
-pub fn into_bigint(a: Fp, N>) -> BigInt {
- let a = a.0;
- raw_reduction(a)
-}
-
-pub fn raw_reduction(a: BigInt) -> BigInt {
- let mut r = a.0; // parse into [u64; N]
-
- // Montgomery Reduction
- for i in 0..N {
- let k = r[i].wrapping_mul(>::INV);
- let mut carry = 0;
-
- fa::mac_with_carry(
- r[i],
- k,
- >::MODULUS.0[0],
- &mut carry,
- );
- for j in 1..N {
- r[(j + i) % N] = fa::mac_with_carry(
- r[(j + i) % N],
- k,
- >::MODULUS.0[j],
- &mut carry,
- );
- }
- r[i % N] = carry;
- }
- BigInt::new(r)
-}
diff --git a/mopro-msm/src/msm/metal/abstraction/state.rs b/mopro-msm/src/msm/metal/abstraction/state.rs
deleted file mode 100644
index 024aa803..00000000
--- a/mopro-msm/src/msm/metal/abstraction/state.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-use metal::{ComputeCommandEncoderRef, MTLResourceOptions};
-
-use crate::msm::metal::abstraction::errors::MetalError;
-
-use core::{ffi, mem};
-use std::{env, fs, path::Path};
-
-/// Structure for abstracting basic calls to a Metal device and saving the state. Used for
-/// implementing GPU parallel computations in Apple machines.
-pub struct MetalState {
- pub device: metal::Device,
- pub library: metal::Library,
- pub queue: metal::CommandQueue,
-}
-
-impl MetalState {
- /// Creates a new Metal state with an optional `device` (GPU). If `None` is passed then it will use
- /// the system's default.
- pub fn new(device: Option) -> Result {
- let device: metal::Device =
- device.unwrap_or(metal::Device::system_default().ok_or(MetalError::DeviceNotFound())?);
-
- let metallib_path = Path::new(env!("OUT_DIR")).join("msm.metallib");
-
- let lib_data = fs::read(metallib_path)
- .expect(format!("Missing metal library on the path {}", env!("OUT_DIR")).as_str());
-
- let library = device
- .new_library_with_data(&lib_data)
- .map_err(MetalError::LibraryError)?;
- let queue = device.new_command_queue();
-
- Ok(Self {
- device,
- library,
- queue,
- })
- }
-
- /// Creates a pipeline based on a compute function `kernel` which needs to exist in the state's
- /// library. A pipeline is used for issuing commands to the GPU through command buffers,
- /// executing the `kernel` function.
- pub fn setup_pipeline(
- &self,
- kernel_name: &str,
- ) -> Result {
- let kernel = self
- .library
- .get_function(kernel_name, None)
- .map_err(MetalError::FunctionError)?;
-
- let pipeline = self
- .device
- .new_compute_pipeline_state_with_function(&kernel)
- .map_err(MetalError::PipelineError)?;
-
- Ok(pipeline)
- }
-
- /// Allocates `length` bytes of shared memory between CPU and the device (GPU).
- pub fn alloc_buffer(&self, length: usize) -> metal::Buffer {
- let size = mem::size_of::();
-
- self.device.new_buffer(
- (length * size) as u64,
- MTLResourceOptions::StorageModeShared, // TODO: use managed mode
- )
- }
-
- /// Allocates `data` in a buffer of shared memory between CPU and the device (GPU).
- pub fn alloc_buffer_data(&self, data: &[T]) -> metal::Buffer {
- let size = mem::size_of::();
-
- self.device.new_buffer_with_data(
- data.as_ptr() as *const ffi::c_void,
- (data.len() * size) as u64,
- MTLResourceOptions::StorageModeShared, // TODO: use managed mode
- )
- }
-
- pub fn set_bytes(index: usize, data: &[T], encoder: &ComputeCommandEncoderRef) {
- let size = mem::size_of::();
-
- encoder.set_bytes(
- index as u64,
- (data.len() * size) as u64,
- data.as_ptr() as *const ffi::c_void,
- );
- }
-
- /// Creates a command buffer and a compute encoder in a pipeline, optionally issuing `buffers`
- /// to it.
- pub fn setup_command(
- &self,
- pipeline: &metal::ComputePipelineState,
- buffers: Option<&[(u64, &metal::Buffer)]>,
- ) -> (&metal::CommandBufferRef, &metal::ComputeCommandEncoderRef) {
- let command_buffer = self.queue.new_command_buffer();
- let command_encoder = command_buffer.new_compute_command_encoder();
- command_encoder.set_compute_pipeline_state(pipeline);
-
- if let Some(buffers) = buffers {
- for (i, buffer) in buffers.iter() {
- command_encoder.set_buffer(*i, Some(buffer), 0);
- }
- }
-
- (command_buffer, command_encoder)
- }
-
- /// Returns a vector of a copy of the data that `buffer` holds, interpreting it into a specific
- /// type `T`.
- ///
- /// BEWARE: this function uses an unsafe function for retrieveing the data, if the buffer's
- /// contents don't match the specified `T`, expect undefined behaviour. Always make sure the
- /// buffer you are retreiving from holds data of type `T`.
- pub fn retrieve_contents(buffer: &metal::Buffer) -> Vec {
- let ptr = buffer.contents() as *const T;
- let len = buffer.length() as usize / mem::size_of::();
- let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
-
- slice.to_vec()
- }
-}
diff --git a/mopro-msm/src/msm/metal/compile_metal_iphone.sh b/mopro-msm/src/msm/metal/compile_metal_iphone.sh
deleted file mode 100755
index abf6ce25..00000000
--- a/mopro-msm/src/msm/metal/compile_metal_iphone.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-# every kernel is self contained i.e. is its own crate. Simply `cd` into a kernel's directory and run
-# the following, which compiles our shader to an intermediate representation using the metal utility
-xcrun -sdk iphoneos metal -c ./mopro-msm/src/msm/metal/shader/all.metal -o ./mopro-msm/src/msm/metal/shader/all.air
-
-# next, compile the .air file to generate a .metallib file - which I believe is LLVM IR (need confirmation)
-xcrun -sdk iphoneos metallib ./mopro-msm/src/msm/metal/shader/all.air -o ./mopro-msm/src/msm/metal/shader/msm.metallib
-
-# finally, clean the redundant .air file
-rm -f ./mopro-msm/src/msm/metal/shader/all.air
\ No newline at end of file
diff --git a/mopro-msm/src/msm/metal/mod.rs b/mopro-msm/src/msm/metal/mod.rs
deleted file mode 100644
index 98cc2a73..00000000
--- a/mopro-msm/src/msm/metal/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod abstraction;
-pub mod msm;
-pub mod tests;
diff --git a/mopro-msm/src/msm/metal/msm.rs b/mopro-msm/src/msm/metal/msm.rs
deleted file mode 100644
index 206cda9f..00000000
--- a/mopro-msm/src/msm/metal/msm.rs
+++ /dev/null
@@ -1,606 +0,0 @@
-use ark_bn254::{Fq, Fr as ScalarField, G1Affine as GAffine, G1Projective as G};
-use ark_ec::AffineRepr;
-use ark_ff::PrimeField;
-use ark_std::{cfg_into_iter, vec::Vec};
-// For benchmarking
-use std::time::{Duration, Instant};
-
-use crate::msm::metal::abstraction::{
- errors::MetalError,
- limbs_conversion::{FromLimbs, ToLimbs},
- state::*,
-};
-use crate::msm::utils::{benchmark::BenchmarkResult, preprocess};
-
-use metal::*;
-use objc::rc::autoreleasepool;
-use rayon::prelude::*;
-
-pub struct MetalMsmData {
- pub window_size_buffer: Buffer,
- pub instances_size_buffer: Buffer,
- pub window_starts_buffer: Buffer,
- pub scalar_buffer: Buffer,
- pub base_buffer: Buffer,
- pub num_windows_buffer: Buffer,
- pub buckets_indices_buffer: Buffer,
- pub buckets_matrix_buffer: Buffer,
- pub res_buffer: Buffer,
- pub result_buffer: Buffer,
- // pub debug_buffer: Buffer,
-}
-
-pub struct MetalMsmParams {
- pub instances_size: u32,
- pub buckets_size: u32,
- pub window_size: u32,
- pub num_window: u64,
-}
-
-pub struct MetalMsmPipeline {
- pub init_buckets: ComputePipelineState,
- pub accumulation_and_reduction: ComputePipelineState,
- pub final_accumulation: ComputePipelineState,
- pub prepare_buckets_indices: ComputePipelineState,
- pub bucket_wise_accumulation: ComputePipelineState,
- pub sum_reduction: ComputePipelineState,
-}
-
-pub struct MetalMsmConfig {
- pub state: MetalState,
- pub pipelines: MetalMsmPipeline,
-}
-
-pub struct MetalMsmInstance {
- pub data: MetalMsmData,
- pub params: MetalMsmParams,
-}
-
-// Helper function for getting the windows size
-fn ln_without_floats(a: usize) -> usize {
- // log2(a) * ln(2)
- (ark_std::log2(a) * 69 / 100) as usize
-}
-
-fn sort_buckets_indices(buckets_indices: &mut Vec) -> () {
- // parse the buckets_indices to a Vec<(u32, u32)>
- let mut buckets_indices_pairs: Vec<(u32, u32)> = Vec::new();
- for i in 0..buckets_indices.len() / 2 {
- // skip empty indices (0, 0)
- if buckets_indices[2 * i] == 0 && buckets_indices[2 * i + 1] == 0 {
- continue;
- }
- buckets_indices_pairs.push((buckets_indices[2 * i], buckets_indices[2 * i + 1]));
- }
- // parallel sort the buckets_indices_pairs by the first element
- buckets_indices_pairs.par_sort_by(|a, b| a.0.cmp(&b.0));
-
- // flatten the sorted pairs to a Vec
- buckets_indices.clear();
- for (start, end) in buckets_indices_pairs {
- buckets_indices.push(start);
- buckets_indices.push(end);
- }
-}
-
-pub fn setup_metal_state() -> MetalMsmConfig {
- let state = MetalState::new(None).unwrap();
- let init_buckets = state.setup_pipeline("initialize_buckets").unwrap();
- let accumulation_and_reduction = state
- .setup_pipeline("accumulation_and_reduction_phase")
- .unwrap();
- let final_accumulation = state.setup_pipeline("final_accumulation").unwrap();
-
- // TODO:
- let prepare_buckets_indices = state.setup_pipeline("prepare_buckets_indices").unwrap();
- let bucket_wise_accumulation = state.setup_pipeline("bucket_wise_accumulation").unwrap();
- let sum_reduction = state.setup_pipeline("sum_reduction").unwrap();
- // let make_histogram_uint32 = state.setup_pipeline("make_histogram_uint32").unwrap();
- // let reorder_uint32 = state.setup_pipeline("reorder_uint32").unwrap();
-
- // let make_histogram_uint32_raw = state.library.get_function("reorder_uint32", None).unwrap();
- // let tmp = state.setup_pipeline("reorder_uint32").unwrap();
- // println!("tmp: {:?}", tmp);
- // state.library.function_names().iter().for_each(|name| {
- // println!("Function name: {:?}", name);
- // });
- // let compute_descriptor = ComputePipelineDescriptor::new();
- // compute_descriptor.set_compute_function(Some(&make_histogram_uint32_raw));
- // println!("make_histogram_uint32: {:?}", compute_descriptor.compute_function().unwrap());
- // println!("make_histogram_uint32: {:?}", result);
-
- MetalMsmConfig {
- state,
- pipelines: MetalMsmPipeline {
- init_buckets,
- accumulation_and_reduction,
- final_accumulation,
- prepare_buckets_indices,
- bucket_wise_accumulation,
- sum_reduction,
- },
- }
-}
-
-pub fn encode_instances(
- points: &[GAffine],
- scalars: &[ScalarField],
- config: &mut MetalMsmConfig,
-) -> MetalMsmInstance {
- let modulus_bit_size = ScalarField::MODULUS_BIT_SIZE as usize;
-
- let instances_size = ark_std::cmp::min(points.len(), scalars.len());
- let window_size = if instances_size < 32 {
- 3
- } else {
- ln_without_floats(instances_size) + 2
- };
- let buckets_size = (1 << window_size) - 1;
- let window_starts: Vec = (0..modulus_bit_size).step_by(window_size).collect();
- let num_windows = window_starts.len();
-
- // flatten scalar and base to Vec for GPU usage
- let scalars_limbs = cfg_into_iter!(scalars)
- .map(|s| s.into_bigint().to_u32_limbs())
- .flatten()
- .collect::>();
- let bases_limbs = cfg_into_iter!(points)
- .map(|b| {
- let b = b.into_group();
- b.x.to_u32_limbs()
- .into_iter()
- .chain(b.y.to_u32_limbs())
- .chain(b.z.to_u32_limbs())
- .collect::>()
- })
- .flatten()
- .collect::>();
-
- // store params to GPU shared memory
- let window_size_buffer = config.state.alloc_buffer_data(&[window_size as u32]);
- let instances_size_buffer = config.state.alloc_buffer_data(&[instances_size as u32]);
- let scalar_buffer = config.state.alloc_buffer_data(&scalars_limbs);
- let base_buffer = config.state.alloc_buffer_data(&bases_limbs);
- let num_windows_buffer = config.state.alloc_buffer_data(&[num_windows as u32]);
- let buckets_matrix_buffer = config
- .state
- .alloc_buffer::(buckets_size * num_windows * 8 * 3);
- let res_buffer = config.state.alloc_buffer::(num_windows * 8 * 3);
- let result_buffer = config.state.alloc_buffer::(8 * 3);
- // convert window_starts to u32 to give the exact storage need for GPU
- let window_starts_buffer = config.state.alloc_buffer_data(
- &(window_starts
- .iter()
- .map(|x| *x as u32)
- .collect::>()),
- );
- // prepare bucket_size * num_windows * 2
- let buckets_indices_buffer = config
- .state
- .alloc_buffer::(instances_size * num_windows * 2);
-
- // // debug
- // let debug_buffer = config.state.alloc_buffer::(2048);
-
- MetalMsmInstance {
- data: MetalMsmData {
- window_size_buffer,
- instances_size_buffer,
- window_starts_buffer,
- scalar_buffer,
- base_buffer,
- num_windows_buffer,
- buckets_matrix_buffer,
- buckets_indices_buffer,
- res_buffer,
- result_buffer,
- // debug_buffer,
- },
- params: MetalMsmParams {
- instances_size: instances_size as u32,
- buckets_size: buckets_size as u32,
- window_size: window_size as u32,
- num_window: num_windows as u64,
- },
- }
-}
-
-pub fn exec_metal_commands(
- config: &MetalMsmConfig,
- instance: MetalMsmInstance,
-) -> Result {
- let data = instance.data;
- let params = instance.params;
-
- // Init the pipleline for MSM
- let init_time = Instant::now();
- autoreleasepool(|| {
- let (command_buffer, command_encoder) = config.state.setup_command(
- &config.pipelines.init_buckets,
- Some(&[
- (0, &data.window_size_buffer),
- (1, &data.window_starts_buffer),
- (2, &data.buckets_matrix_buffer),
- ]),
- );
- command_encoder
- .dispatch_thread_groups(MTLSize::new(params.num_window, 1, 1), MTLSize::new(1, 1, 1));
- command_encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- });
- println!("Init buckets time: {:?}", init_time.elapsed());
-
- let prepare_time = Instant::now();
- autoreleasepool(|| {
- let (command_buffer, command_encoder) = config.state.setup_command(
- &config.pipelines.prepare_buckets_indices,
- Some(&[
- (0, &data.window_size_buffer),
- (1, &data.window_starts_buffer),
- (2, &data.num_windows_buffer),
- (3, &data.scalar_buffer),
- (4, &data.buckets_indices_buffer),
- ]),
- );
- command_encoder.dispatch_thread_groups(
- MTLSize::new(params.instances_size as u64, 1, 1),
- MTLSize::new(1, 1, 1),
- );
- command_encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- });
- println!("Prepare buckets indices time: {:?}", prepare_time.elapsed());
-
- // sort the buckets_indices in CPU parallelly
- let sort_start = Instant::now();
- let mut buckets_indices = MetalState::retrieve_contents::(&data.buckets_indices_buffer);
- sort_buckets_indices(&mut buckets_indices);
-
- // send the sorted buckets back to GPU
- let sorted_buckets_indices_buffer = config.state.alloc_buffer_data(&buckets_indices);
- println!("Sort buckets indices time: {:?}", sort_start.elapsed());
-
- // accumulate the buckets_matrix using sorted bucket indices on GPU
- let max_threads_per_group = MTLSize::new(
- config
- .pipelines
- .bucket_wise_accumulation
- .thread_execution_width(),
- config
- .pipelines
- .bucket_wise_accumulation
- .max_total_threads_per_threadgroup()
- / config
- .pipelines
- .bucket_wise_accumulation
- .thread_execution_width(),
- 1,
- );
- let max_thread_size = params.buckets_size as u64 * params.num_window;
- let opt_threadgroups_amount = max_thread_size
- / config
- .pipelines
- .bucket_wise_accumulation
- .max_total_threads_per_threadgroup()
- + 1;
- let opt_threadgroups = MTLSize::new(opt_threadgroups_amount, 1, 1);
- println!(
- "(accumulation) max thread per threadgroup: {:?}",
- max_threads_per_group
- );
- println!("(accumulation) opt threadgroups: {:?}", opt_threadgroups);
-
- let max_thread_size_accu_buffer = config.state.alloc_buffer_data(&[max_thread_size as u32]);
- let bucket_wise_time = Instant::now();
- autoreleasepool(|| {
- let (command_buffer, command_encoder) = config.state.setup_command(
- &config.pipelines.bucket_wise_accumulation,
- Some(&[
- (0, &data.instances_size_buffer),
- (1, &data.num_windows_buffer),
- (2, &data.base_buffer),
- (3, &sorted_buckets_indices_buffer),
- (4, &data.buckets_matrix_buffer),
- (5, &max_thread_size_accu_buffer),
- // (6, &data.debug_buffer),
- ]),
- );
- // command_encoder.dispatch_thread_groups(
- // MTLSize::new(params.buckets_size as u64 * params.num_window, 1, 1),
- // MTLSize::new(1, 1, 1),
- // );
- command_encoder.dispatch_thread_groups(opt_threadgroups, max_threads_per_group);
- command_encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- });
- let bucket_wise_elapsed = bucket_wise_time.elapsed();
- println!(
- "Bucket wise accumulation time (using {:?} threads): {:?}",
- params.buckets_size as u64 * params.num_window,
- bucket_wise_elapsed
- );
-
- // // debug
- // let debug_data = MetalState::retrieve_contents::(&data.debug_buffer);
- // println!("Debug data: {:?}", debug_data);
-
- // Reduce the buckets_matrix on GPU
- let max_thread_size = params.num_window;
- let opt_threadgroups_amount = max_thread_size
- / config
- .pipelines
- .bucket_wise_accumulation
- .max_total_threads_per_threadgroup()
- + 1;
- let opt_threadgroups = MTLSize::new(opt_threadgroups_amount, 1, 1);
- let max_thread_size_reduc_buffer = config.state.alloc_buffer_data(&[max_thread_size as u32]);
- let reduction_time = Instant::now();
- autoreleasepool(|| {
- let (command_buffer, command_encoder) = config.state.setup_command(
- &config.pipelines.sum_reduction,
- Some(&[
- (0, &data.window_size_buffer),
- (1, &data.scalar_buffer),
- (2, &data.base_buffer),
- (3, &data.buckets_matrix_buffer),
- (4, &data.res_buffer),
- (5, &max_thread_size_reduc_buffer),
- ]),
- );
- // command_encoder
- // .dispatch_thread_groups(MTLSize::new(params.num_window, 1, 1), MTLSize::new(1, 1, 1));
- command_encoder.dispatch_thread_groups(opt_threadgroups, max_threads_per_group);
- command_encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- });
- println!("Reduction time: {:?}", reduction_time.elapsed());
-
- // Sequentially accumulate the msm results on GPU
- let final_time = Instant::now();
- autoreleasepool(|| {
- let (command_buffer, command_encoder) = config.state.setup_command(
- &config.pipelines.final_accumulation,
- Some(&[
- (0, &data.window_size_buffer),
- (1, &data.window_starts_buffer),
- (2, &data.num_windows_buffer),
- (3, &data.res_buffer),
- (4, &data.result_buffer),
- ]),
- );
- command_encoder.dispatch_thread_groups(MTLSize::new(1, 1, 1), MTLSize::new(1, 1, 1));
- command_encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- });
- println!("Final accumulation time: {:?}", final_time.elapsed());
-
- // retrieve and parse the result from GPU
- let msm_result = {
- let raw_limbs = MetalState::retrieve_contents::(&data.result_buffer);
- G::new_unchecked(
- Fq::from_u32_limbs(&raw_limbs[0..8]),
- Fq::from_u32_limbs(&raw_limbs[8..16]),
- Fq::from_u32_limbs(&raw_limbs[16..24]),
- )
- };
-
- Ok(msm_result)
-}
-
-pub fn metal_msm(
- points: &[GAffine],
- scalars: &[ScalarField],
- config: &mut MetalMsmConfig,
-) -> Result {
- let instance = encode_instances(points, scalars, config);
- exec_metal_commands(config, instance)
-}
-
-pub fn benchmark_msm(
- instances: I,
- iterations: u32,
-) -> Result, preprocess::HarnessError>
-where
- I: Iterator
- ,
-{
- println!("Init metal (GPU) state...");
- let init_start = Instant::now();
- let mut metal_config = setup_metal_state();
- let init_duration = init_start.elapsed();
- println!("Done initializing metal (GPU) state in {:?}", init_duration);
-
- let mut instance_durations = Vec::new();
- for instance in instances {
- let points = &instance.0;
- // map each scalar to a ScalarField
- let scalars = &instance
- .1
- .iter()
- .map(|s| ScalarField::new(*s))
- .collect::>();
-
- let mut instance_total_duration = Duration::ZERO;
- for _i in 0..iterations {
- let encoding_data_start = Instant::now();
- println!("Encoding instance to GPU memory...");
- let metal_instance = encode_instances(&points[..], &scalars[..], &mut metal_config);
- let encoding_data_duration = encoding_data_start.elapsed();
- println!("Done encoding data in {:?}", encoding_data_duration);
-
- let msm_start = Instant::now();
- let _result = exec_metal_commands(&metal_config, metal_instance).unwrap();
- instance_total_duration += msm_start.elapsed();
- }
- let instance_avg_duration = instance_total_duration / iterations;
-
- println!(
- "Average time to execute MSM with {} points and {} scalars in {} iterations is: {:?}",
- points.len(),
- scalars.len(),
- iterations,
- instance_avg_duration,
- );
- instance_durations.push(instance_avg_duration);
- }
- Ok(instance_durations)
-}
-
-pub fn run_benchmark(
- instance_size: u32,
- num_instance: u32,
- utils_dir: &str,
-) -> Result {
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&utils_dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(instance_size, num_instance, &utils_dir);
- }
- }
-
- let benchmark_data = preprocess::FileInputIterator::open(&utils_dir).unwrap();
- let instance_durations = benchmark_msm(benchmark_data, 1).unwrap();
- // in milliseconds
- let avg_processing_time: f64 = instance_durations
- .iter()
- .map(|d| d.as_secs_f64() * 1000.0)
- .sum::()
- / instance_durations.len() as f64;
-
- println!("Done running benchmark.");
- Ok(BenchmarkResult {
- instance_size: instance_size,
- num_instance: num_instance,
- avg_processing_time: avg_processing_time,
- })
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use ark_ec::{CurveGroup, Group, VariableBaseMSM};
- use ark_serialize::Write;
- use ark_std::UniformRand;
- use std::fs::File;
-
- const INSTANCE_SIZE: u32 = 16;
- const NUM_INSTANCE: u32 = 5;
- const UTILSPATH: &str = "src/msm/utils/vectors";
- const BENCHMARKSPATH: &str = "benchmark_results";
-
- #[test]
- fn test_msm_correctness_medium_sample() {
- let dir = format!("{}/{}/{}x{}", preprocess::get_root_path(), UTILSPATH, 8, 5);
- // Init metal (GPU) state
- let mut metal_config = setup_metal_state();
-
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(INSTANCE_SIZE, NUM_INSTANCE, &dir);
- }
- }
-
- let instances = preprocess::FileInputIterator::open(&dir).unwrap();
-
- for (i, instance) in instances.enumerate() {
- let points = &instance.0;
- // map each scalar to a ScalarField
- let scalars = &instance
- .1
- .iter()
- .map(|s| ScalarField::new(*s))
- .collect::>();
- let arkworks_msm = G::msm(&points[..], &scalars[..]).unwrap();
- let metal_msm = metal_msm(&points[..], &scalars[..], &mut metal_config).unwrap();
- assert_eq!(metal_msm, arkworks_msm, "This msm is wrongly computed");
- println!(
- "(pass) {}th instance of size 2^{} is correctly computed",
- i, 8
- );
- }
- }
-
- #[test]
- fn test_benchmark_msm() {
- let dir = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- UTILSPATH,
- INSTANCE_SIZE,
- NUM_INSTANCE
- );
-
- // Check if the vectors have been generated
- match preprocess::FileInputIterator::open(&dir) {
- Ok(_) => {
- println!("Vectors already generated");
- }
- Err(_) => {
- preprocess::gen_vectors(INSTANCE_SIZE, NUM_INSTANCE, &dir);
- }
- }
-
- let benchmark_data = preprocess::FileInputIterator::open(&dir).unwrap();
- let result = benchmark_msm(benchmark_data, 1);
- println!("Done running benchmark: {:?}", result);
- }
-
- #[test]
- fn test_run_benchmark() {
- let utils_path = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- &UTILSPATH,
- INSTANCE_SIZE,
- NUM_INSTANCE
- );
- let result = run_benchmark(INSTANCE_SIZE, NUM_INSTANCE, &utils_path).unwrap();
- println!("Benchmark result: {:#?}", result);
- }
-
- #[test]
- fn test_run_multi_benchmarks() {
- let output_path = format!(
- "{}/{}/{}_benchmark.txt",
- preprocess::get_root_path(),
- &BENCHMARKSPATH,
- "metal_msm"
- );
- let mut output_file = File::create(output_path).expect("output file creation failed");
- writeln!(output_file, "msm_size,num_msm,avg_processing_time(ms)").unwrap();
-
- let instance_size = vec![8, 12, 16, 18, 20, 22];
- let num_instance = vec![10];
- for size in &instance_size {
- for num in &num_instance {
- let utils_path = format!(
- "{}/{}/{}x{}",
- preprocess::get_root_path(),
- &UTILSPATH,
- *size,
- *num
- );
- let result = run_benchmark(*size, *num, &utils_path).unwrap();
- println!("{}x{} result: {:#?}", *size, *num, result);
- writeln!(
- output_file,
- "{},{},{}",
- result.instance_size, result.num_instance, result.avg_processing_time
- )
- .unwrap();
- }
- }
- }
-}
diff --git a/mopro-msm/src/msm/metal/shader/all.metal b/mopro-msm/src/msm/metal/shader/all.metal
deleted file mode 100644
index 5915eb5c..00000000
--- a/mopro-msm/src/msm/metal/shader/all.metal
+++ /dev/null
@@ -1,11 +0,0 @@
-// This is necessary because pragma once doesn't work as expected
-// and some symbols are being duplicated.
-
-// TODO: Investigate this issue, having .metal sources would be better
-// than headers and a unique source.
-
-#include "fields/fp_bn254.h.metal"
-#include "tests/test_bn254.h.metal"
-#include "tests/test_unsigned_integer.h.metal"
-// #include "utils/parallel_radix_sort.h.metal"
-#include "msm.h.metal"
diff --git a/mopro-msm/src/msm/metal/shader/arithmetics/u128.h.metal b/mopro-msm/src/msm/metal/shader/arithmetics/u128.h.metal
deleted file mode 100644
index 627ff2dd..00000000
--- a/mopro-msm/src/msm/metal/shader/arithmetics/u128.h.metal
+++ /dev/null
@@ -1,235 +0,0 @@
-// https://github.com/andrewmilson/ministark/blob/main/gpu-poly/src/metal/u128.h.metal
-
-#pragma once
-
-#include
-
-class u128
-{
-public:
- u128() = default;
- constexpr u128(int l) : low(l), high(0) {}
- constexpr u128(unsigned long l) : low(l), high(0) {}
- constexpr u128(bool b) : low(b), high(0) {}
- constexpr u128(unsigned long h, unsigned long l) : low(l), high(h) {}
-
- constexpr u128 operator+(const u128 rhs) const
- {
- return u128(high + rhs.high + ((low + rhs.low) < low), low + rhs.low);
- }
-
- constexpr u128 operator+=(const u128 rhs)
- {
- *this = *this + rhs;
- return *this;
- }
-
- constexpr inline u128 operator-(const u128 rhs) const
- {
- return u128(high - rhs.high - ((low - rhs.low) > low), low - rhs.low);
- }
-
- constexpr u128 operator-=(const u128 rhs)
- {
- *this = *this - rhs;
- return *this;
- }
-
- constexpr bool operator==(const u128 rhs) const
- {
- return high == rhs.high && low == rhs.low;
- }
-
- constexpr bool operator!=(const u128 rhs) const
- {
- return !(*this == rhs);
- }
-
- constexpr bool operator<(const u128 rhs) const
- {
- return ((high == rhs.high) && (low < rhs.low)) || (high < rhs.high);
- }
-
- constexpr u128 operator&(const u128 rhs) const
- {
- return u128(high & rhs.high, low & rhs.low);
- }
-
- constexpr u128 operator|(const u128 rhs) const
- {
- return u128(high | rhs.high, low | rhs.low);
- }
-
- constexpr bool operator>(const u128 rhs) const
- {
- return ((high == rhs.high) && (low > rhs.low)) || (high > rhs.high);
- }
-
- constexpr bool operator>=(const u128 rhs) const
- {
- return !(*this < rhs);
- }
-
- constexpr bool operator<=(const u128 rhs) const
- {
- return !(*this > rhs);
- }
-
- constexpr inline u128 operator>>(unsigned shift) const
- {
- uint64_t new_low = (shift == 0) * low
- | (shift == 64) * high
- | ((shift < 64) ^ (shift == 0)) * ((high << (64 - shift)) | (low >> shift))
- | ((shift > 64) & (shift < 128)) * (high >> (shift - 64));
-
- uint64_t new_high = (shift == 0) * high
- | ((shift < 64) ^ (shift == 0)) * (high >> shift);
-
- return u128(new_high, new_low);
-
- // Unoptimized form:
- // if (shift >= 128)
- // return u128(0);
- // else if (shift == 64)
- // return u128(0, high);
- // else if (shift == 0)
- // return *this;
- // else if (shift < 64)
- // return u128(high >> shift, (high << (64 - shift)) | (low >> shift));
- // else if ((128 > shift) && (shift > 64))
- // return u128(0, (high >> (shift - 64)));
- // else
- // return u128(0);
- }
-
- constexpr inline u128 operator<<(unsigned shift) const
- {
- unsigned long new_low = (shift == 0) * low
- | ((shift < 64) ^ (shift == 0)) * (low << shift);
-
- unsigned long new_high = (shift == 0) * high
- | (shift == 64) * low
- | ((shift < 64) ^ (shift == 0)) * (high << shift) | (low >> (64 - shift))
- | ((shift > 64) & (shift < 128)) * (low >> (shift - 64));
-
- return u128(new_high, new_low);
-
- // Unoptimized form:
- // if (shift >= 128)
- // return u128(0);
- // else if (shift == 64)
- // return u128(low, 0);
- // else if (shift == 0)
- // return *this;
- // else if (shift < 64)
- // return u128((high << shift) | (low >> (64 - shift)), low << shift);
- // else if ((128 > shift) && (shift > 64))
- // return u128((low >> (shift - 64)), 0);
- // else
- // return u128(0);
- }
-
- constexpr u128 operator>>=(unsigned rhs)
- {
- *this = *this >> rhs;
- return *this;
- }
-
- u128 operator*(const bool rhs) const
- {
- return u128(high * rhs, low * rhs);
- }
-
- u128 operator*(const u128 rhs) const
- {
- unsigned long t_low_high = low * rhs.high;
- unsigned long t_high = metal::mulhi(low, rhs.low);
- unsigned long t_high_low = high * rhs.low;
- unsigned long t_low = low * rhs.low;
- return u128(t_low_high + t_high_low + t_high, t_low);
-
- // // // split values into 4 32-bit parts
- // // unsigned long top[4] = {high >> 32, high & 0xffffffff, low >> 32, low & 0xffffffff};
- // // unsigned long bottom[4] = {rhs.high >> 32, rhs.high & 0xffffffff, rhs.low >> 32, rhs.low & 0xffffffff};
- // // unsigned long products[4][4];
-
- // // // multiply each component of the values
- // // Alternative:
- // // for(int y = 3; y > -1; y--){
- // // for(int x = 3; x > -1; x--){
- // // products[3 - x][y] = top[x] * bottom[y];
- // // }
- // // }
- // products[0][3] = top[3] * bottom[3];
- // products[1][3] = top[2] * bottom[3];
- // products[2][3] = top[1] * bottom[3];
- // products[3][3] = top[0] * bottom[3];
-
- // products[0][2] = top[3] * bottom[2];
- // products[1][2] = top[2] * bottom[2];
- // products[2][2] = top[1] * bottom[2];
- // // products[3][2] = top[0] * bottom[2];
-
- // products[0][1] = top[3] * bottom[1];
- // products[1][1] = top[2] * bottom[1];
- // // products[2][1] = top[1] * bottom[1];
- // products[3][1] = top[0] * bottom[1];
-
- // products[0][0] = top[3] * bottom[0];
- // // products[1][0] = top[2] * bottom[0];
- // // products[2][0] = top[1] * bottom[0];
- // // products[3][0] = top[0] * bottom[0];
-
- // // first row
- // unsigned long fourth32 = products[0][3] & 0xffffffff;
- // unsigned long third32 = (products[0][2] & 0xffffffff) + (products[0][3] >> 32);
- // unsigned long second32 = (products[0][1] & 0xffffffff) + (products[0][2] >> 32);
- // unsigned long first32 = (products[0][0] & 0xffffffff) + (products[0][1] >> 32);
-
- // // second row
- // third32 += products[1][3] & 0xffffffff;
- // second32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32);
- // first32 += (products[1][1] & 0xffffffff) + (products[1][2] >> 32);
-
- // // third row
- // second32 += products[2][3] & 0xffffffff;
- // first32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32);
-
- // // fourth row
- // first32 += products[3][3] & 0xffffffff;
-
- // // move carry to next digit
- // // third32 += fourth32 >> 32; // TODO: figure out if this is a nop
- // second32 += third32 >> 32;
- // first32 += second32 >> 32;
-
- // // remove carry from current digit
- // // fourth32 &= 0xffffffff; // TODO: figure out if this is a nop
- // // third32 &= 0xffffffff;
- // second32 &= 0xffffffff;
- // // first32 &= 0xffffffff;
-
- // // combine components
- // // return u128((first32 << 32) | second32, (third32 << 32) | fourth32);
- // return u128((first32 << 32) | second32, (third32 << 32) | fourth32);
- }
-
- u128 operator*=(const u128 rhs)
- {
- *this = *this * rhs;
- return *this;
- }
-
- unsigned long high;
- unsigned long low;
-
- // TODO: Could get better performance with smaller limb size
- // Not sure what word size is for M1 GPU
-// #ifdef __LITTLE_ENDIAN__
-// unsigned long low;
-// unsigned long high;
-// #endif
-// #ifdef __BIG_ENDIAN__
-
-// #endif
-};
diff --git a/mopro-msm/src/msm/metal/shader/arithmetics/u256.h.metal b/mopro-msm/src/msm/metal/shader/arithmetics/u256.h.metal
deleted file mode 100644
index 6ff5536c..00000000
--- a/mopro-msm/src/msm/metal/shader/arithmetics/u256.h.metal
+++ /dev/null
@@ -1,262 +0,0 @@
-// https://github.com/andrewmilson/ministark/blob/6e96f6b6c83b7faf38a9e015bbedf2aa7b984092/gpu/src/metal/u256.h.metal
-
-#pragma once
-
-#include
-#include "u128.h.metal"
-
-class u256
-{
-public:
- u256() = default;
- constexpr u256(int l) : low(l), high(0) {}
- constexpr u256(unsigned long l) : low(u128(l)), high(0) {}
- constexpr u256(u128 l) : low(l), high(0) {}
- constexpr u256(bool b) : low(b), high(0) {}
- constexpr u256(u128 h, u128 l) : low(l), high(h) {}
- constexpr u256(unsigned long hh, unsigned long hl, unsigned long lh, unsigned long ll) :
- low(u128(lh, ll)), high(u128(hh, hl)) {}
-
- constexpr u256 operator+(const u256 rhs) const
- {
- return u256(high + rhs.high + ((low + rhs.low) < low), low + rhs.low);
- }
-
- constexpr u256 operator+=(const u256 rhs)
- {
- *this = *this + rhs;
- return *this;
- }
-
- constexpr inline u256 operator-(const u256 rhs) const
- {
- return u256(high - rhs.high - ((low - rhs.low) > low), low - rhs.low);
- }
-
- constexpr u256 operator-=(const u256 rhs)
- {
- *this = *this - rhs;
- return *this;
- }
-
- constexpr bool operator==(const u256 rhs) const
- {
- return high == rhs.high && low == rhs.low;
- }
-
- constexpr bool operator!=(const u256 rhs) const
- {
- return !(*this == rhs);
- }
-
- constexpr bool operator<(const u256 rhs) const
- {
- return ((high == rhs.high) && (low < rhs.low)) || (high < rhs.high);
- }
-
- constexpr u256 operator&(const u256 rhs) const
- {
- return u256(high & rhs.high, low & rhs.low);
- }
-
- constexpr bool operator>(const u256 rhs) const
- {
- return ((high == rhs.high) && (low > rhs.low)) || (high > rhs.high);
- }
-
- constexpr bool operator>=(const u256 rhs) const
- {
- return !(*this < rhs);
- }
-
- constexpr bool operator<=(const u256 rhs) const
- {
- return !(*this > rhs);
- }
-
- inline u256 operator>>(const unsigned shift) const
- {
- u128 new_low = low * (shift == 0)
- | high * (shift == 128)
- | (high << (128 - shift) | (low >> shift)) * ((shift < 128) ^ (shift == 0))
- | (high >> (shift - 128)) * ((shift < 256) & (shift > 128));
-
- u128 new_high = high * (shift == 0)
- | (high >> shift) * ((shift < 128) ^ (shift == 0));
-
- return u256(new_high, new_low);
-
- // Unoptimized form:
- // if (shift >= 256)
- // return u256(0);
- // else if (shift == 128)
- // return u256(0, high);
- // else if (shift == 0)
- // return *this;
- // else if (shift < 128)
- // return u256(high >> shift, (high << (128 - shift)) | (low >> shift));
- // else if ((256 > shift) && (shift > 128))
- // return u256(0, (high >> (shift - 128)));
- // else
- // return u256(0);
- }
-
- constexpr u256 operator>>=(unsigned rhs)
- {
- *this = *this >> rhs;
- return *this;
- }
-
- u256 operator*(const bool rhs) const
- {
- return u256(high * rhs, low * rhs);
- }
-
- u256 operator*(const u256 rhs) const
- {
- // split values into 4 64-bit parts
- u128 top[2] = {u128(low.high), u128(low.low)};
- u128 bottom[3] = {u128(rhs.high.low), u128(rhs.low.high), u128(rhs.low.low)};
-
- unsigned long tmp3_3 = high.high * rhs.low.low;
- unsigned long tmp0_0 = low.low * rhs.high.high;
- unsigned long tmp2_2 = high.low * rhs.low.high;
-
- u128 tmp2_3 = u128(high.low) * bottom[2];
- u128 tmp0_3 = top[1] * bottom[2];
- u128 tmp1_3 = top[0] * bottom[2];
-
- u128 tmp0_2 = top[1] * bottom[1];
- u128 third64 = u128(tmp0_2.low) + u128(tmp0_3.high);
- u128 tmp1_2 = top[0] * bottom[1];
-
- u128 tmp0_1 = top[1] * bottom[0];
- u128 second64 = u128(tmp0_1.low) + u128(tmp0_2.high);
- unsigned long first64 = tmp0_0 + tmp0_1.high;
-
- u128 tmp1_1 = top[0] * bottom[0];
- first64 += tmp1_1.low + tmp1_2.high;
-
- // second row
- third64 += u128(tmp1_3.low);
- second64 += u128(tmp1_2.low) + u128(tmp1_3.high);
-
- // third row
- second64 += u128(tmp2_3.low);
- first64 += tmp2_2 + tmp2_3.high;
-
- // fourth row
- first64 += tmp3_3;
- second64 += u128(third64.high);
- first64 += second64.high;
-
- return u256(u128(first64, second64.low), u128(third64.low, tmp0_3.low));
-
-
- // // unsigned long t_low_high_low = high * rhs.low;
- // // unsigned long t_low_low_high = low * rhs.high;
-
- // // unsigned long t_low = low * rhs.low;
-
- // // u128 t_low = low * rhs.low;
-
- // // unsigned long t_low_high = metal::mulhi(low.low, rhs.low.high);
- // // unsigned long t_high_low = metal::mulhi(low.high, rhs.low.low);
- // // unsigned long t_high = metal::mulhi(low.low, rhs.low.low);
- // // unsigned long t_low = low.low * rhs.low.low;
-
- // // u128 low_low = u128(t_low_high + t_high_low + t_high, t_low);
-
- // // t_low_high = metal::mulhi(low.low, rhs.low.high);
- // // t_high_low = metal::mulhi(low.high, rhs.low.low);
- // // t_high = metal::mulhi(low.low, rhs.low.low);
- // // t_low = low.low * rhs.low.low;
-
- // // return ;
-
- // // split values into 4 64-bit parts
- // u128 top[3] = {u128(high.low), u128(low.high), u128(low.low)};
- // u128 bottom[3] = {u128(rhs.high.low), u128(rhs.low.high), u128(rhs.low.low)};
- // // u128 top[4] = {high >> 32, high & 0xffffffff, low >> 32, low & 0xffffffff};
- // // u128 bottom[4] = {rhs.high >> 32, rhs.high & 0xffffffff, rhs.low >> 32, rhs.low & 0xffffffff};
- // // u128 products[4][4];
-
- // // // multiply each component of the values
- // // Alternative:
- // // for(int y = 3; y > -1; y--){
- // // for(int x = 3; x > -1; x--){
- // // products[3 - x][y] = top[x] * bottom[y];
- // // }
- // // }
- // u128 tmp0_3 = top[2] * bottom[2];
- // u128 tmp1_3 = top[1] * bottom[2];
- // u128 tmp2_3 = top[0] * bottom[2];
- // // u128 tmp3_3 = top[0] * bottom[2];
- // unsigned long tmp3_3 = high.high * rhs.low.low;
- // // unsigned long tmp0 = low.low * rhs.high.high;
-
- // u128 tmp0_2 = top[2] * bottom[1];
- // u128 tmp1_2 = top[1] * bottom[1];
- // // u128 tmp2_2 = top[0] * bottom[1];
- // unsigned long tmp2_2 = high.low * rhs.low.high;
-
-
- // u128 tmp0_1 = top[2] * bottom[0];
- // u128 tmp1_1 = top[1] * bottom[0];
- // // u128 tmp3_1 = top[0] * bottom[0];
-
- // unsigned long tmp0_0 = low.low * rhs.high.high;
-
- // // first row
- // u128 fourth64 = tmp0_3.low;
- // u128 third64 = u128(tmp0_2.low) + u128(tmp0_3.high);
- // u128 second64 = u128(tmp0_1.low) + u128(tmp0_2.high);
- // u128 first64 = u128(tmp0_0) + u128(tmp0_1.high);
-
- // // second row
- // third64 += u128(tmp1_3.low);
- // second64 += u128(tmp1_2.low) + u128(tmp1_3.high);
- // first64 += u128(tmp1_1.low) + u128(tmp1_2.high);
-
- // // third row
- // second64 += u128(tmp2_3.low);
- // first64 += u128(tmp2_2) + u128(tmp2_3.high);
-
- // // fourth row
- // first64 += u128(tmp3_3);
- // second64 += u128(third64.high);
- // first64 += u128(second64.high);
-
- // // remove carry from current digit
- // // fourth64 &= 0xffffffff; // TODO: figure out if this is a nop
- // // third64 &= 0xffffffff;
- // // second64 = u128(second64.low);
- // // first64 &= 0xffffffff;
-
- // // combine components
- // // return u256((first64 << 64) | second64, (third64 << 64) | fourth64);
- // return u256(u128(first64.low, second64.low), u128(third64.low, fourth64.low));
-
- // // return u128((first64.high second64, (third64 << 64) | fourth64);
- }
-
- u256 operator*=(const u256 rhs)
- {
- *this = *this * rhs;
- return *this;
- }
-
- // TODO: Could get better performance with smaller limb size
- // Not sure what word size is for M1 GPU
- u128 high;
- u128 low;
-
-
-// #ifdef __LITTLE_ENDIAN__
-// u128 low;
-// u128 high;
-// #endif
-// #ifdef __BIG_ENDIAN__
-
-// #endif
-};
diff --git a/mopro-msm/src/msm/metal/shader/arithmetics/unsigned_int.h.metal b/mopro-msm/src/msm/metal/shader/arithmetics/unsigned_int.h.metal
deleted file mode 100644
index ced897c5..00000000
--- a/mopro-msm/src/msm/metal/shader/arithmetics/unsigned_int.h.metal
+++ /dev/null
@@ -1,262 +0,0 @@
-#ifndef unsigned_int_h
-#define unsigned_int_h
-
-#include
-
-template
-struct UnsignedInteger {
- metal::array m_limbs;
-
- constexpr UnsignedInteger() = default;
-
- constexpr static UnsignedInteger from_int(uint32_t n) {
- UnsignedInteger res;
- res.m_limbs = {};
- res.m_limbs[NUM_LIMBS - 1] = n;
- return res;
- }
-
- constexpr static UnsignedInteger from_int(uint64_t n) {
- UnsignedInteger res;
- res.m_limbs = {};
- res.m_limbs[NUM_LIMBS - 2] = (uint32_t)(n >> 32);
- res.m_limbs[NUM_LIMBS - 1] = (uint32_t)(n & 0xFFFFFFFF);
- return res;
- }
-
- constexpr static UnsignedInteger from_bool(bool b) {
- UnsignedInteger res;
- res.m_limbs = {};
- if (b) {
- res.m_limbs[NUM_LIMBS - 1] = 1;
- }
- return res;
- }
-
- constexpr static UnsignedInteger from_high_low(UnsignedInteger high, UnsignedInteger low) {
- UnsignedInteger res = low;
-
- for (uint64_t i = 0; i < NUM_LIMBS / 2; i++) {
- res.m_limbs[i] = high.m_limbs[i + NUM_LIMBS / 2];
- }
-
- return res;
- }
-
- constexpr UnsignedInteger low() const {
- UnsignedInteger res = *this;
-
- for (uint64_t i = 0; i < NUM_LIMBS / 2; i++) {
- res.m_limbs[i] = 0;
- }
-
- return res;
- }
-
- constexpr UnsignedInteger high() const {
- UnsignedInteger res;
- res.m_limbs = {};
-
- for (uint64_t i = 0; i < NUM_LIMBS / 2; i++) {
- res.m_limbs[NUM_LIMBS / 2 + i] = m_limbs[i];
- }
-
- return res;
- }
-
- static UnsignedInteger max() {
- UnsignedInteger res = {};
-
- for (uint64_t i = 0; i < NUM_LIMBS; i++) {
- res.m_limbs[i] = 0xFFFFFFFF;
- }
-
- return res;
- }
-
- constexpr UnsignedInteger operator+(const UnsignedInteger rhs) const
- {
- metal::array limbs {};
- uint64_t carry = 0;
- int i = NUM_LIMBS;
-
- while (i > 0) {
- uint64_t c = uint64_t(m_limbs[i - 1]) + uint64_t(rhs.m_limbs[i - 1]) + carry;
- limbs[i - 1] = c & 0xFFFFFFFF;
- carry = c >> 32;
- i -= 1;
- }
-
- return UnsignedInteger {limbs};
- }
-
- constexpr bool operator==(const UnsignedInteger rhs) const
- {
- for (uint32_t i = 0; i < NUM_LIMBS; i++) {
- if (m_limbs[i] != rhs.m_limbs[i]) {
- return false;
- }
- }
- return true;
- }
-
- constexpr UnsignedInteger operator+=(const UnsignedInteger rhs)
- {
- *this = *this + rhs;
- return *this;
- }
-
- constexpr UnsignedInteger operator-(const UnsignedInteger rhs) const
- {
- metal::array limbs {};
- uint64_t carry = 0;
- uint64_t i = NUM_LIMBS;
-
- while (i > 0) {
- i -= 1;
- int64_t c = (int64_t)(m_limbs[i]) - (int64_t)(rhs.m_limbs[i]) + carry;
- limbs[i] = c & 0xFFFFFFFF;
- carry = c < 0 ? -1 : 0;
- }
-
- return UnsignedInteger {limbs};
- }
-
- constexpr UnsignedInteger operator-=(const UnsignedInteger rhs)
- {
- *this = *this - rhs;
- return *this;
- }
-
- constexpr UnsignedInteger operator*(const UnsignedInteger rhs) const
- {
- long int INT_NUM_LIMBS = (long int)NUM_LIMBS;
- uint64_t n = 0;
- uint64_t t = 0;
-
- for (long int i = INT_NUM_LIMBS - 1; i >= 0; i--) {
- if (m_limbs[i] != 0) {
- n = INT_NUM_LIMBS - 1 - i;
- }
- if (rhs.m_limbs[i] != 0) {
- t = INT_NUM_LIMBS - 1 - i;
- }
- }
-
- metal::array limbs {};
-
- uint64_t carry = 0;
- for (uint64_t i = 0; i <= t; i++) {
- for (uint64_t j = 0; j <= n; j++) {
- uint64_t uv = (uint64_t)(limbs[NUM_LIMBS - 1 - (i + j)])
- + (uint64_t)(m_limbs[NUM_LIMBS - 1 - j])
- * (uint64_t)(rhs.m_limbs[NUM_LIMBS - 1 - i])
- + carry;
- carry = uv >> 32;
- limbs[NUM_LIMBS - 1 - (i + j)] = uv & 0xFFFFFFFF;
- }
- if (i + n + 1 < NUM_LIMBS) {
- limbs[NUM_LIMBS - 1 - (i + n + 1)] = carry & 0xFFFFFFFF;
- carry = 0;
- }
- }
-
- return UnsignedInteger {limbs};
- }
-
- uint64_t cast(uint32_t n) {
- return ((uint64_t)n) >> 32;
- }
-
- constexpr UnsignedInteger operator*=(const UnsignedInteger rhs)
- {
- *this = *this * rhs;
- return *this;
- }
-
- constexpr UnsignedInteger operator<<(const uint32_t times) const
- {
- metal::array limbs {};
- uint32_t a = times / 32;
- uint32_t b = times % 32;
-
- if (b == 0) {
- int64_t i = 0;
- while (i < (int64_t)NUM_LIMBS - (int64_t)a) {
- limbs[i] = m_limbs[a + i];
- i += 1;
- }
- } else {
- limbs[NUM_LIMBS - 1 - a] = m_limbs[NUM_LIMBS - 1] << b;
- uint64_t i = a + 1;
- while (i < NUM_LIMBS) {
- limbs[NUM_LIMBS - 1 - i] = (m_limbs[NUM_LIMBS - 1 - i + a] << b) | (m_limbs[NUM_LIMBS - i + a] >> (32 - b));
- i += 1;
- }
- }
-
- return UnsignedInteger {limbs};
- }
-
- constexpr UnsignedInteger operator>>(const uint32_t times) const
- {
- metal::array limbs {};
- uint32_t a = times / 32;
- uint32_t b = times % 32;
-
- if (b == 0) {
- int64_t i = 0;
- while (i < (int64_t)NUM_LIMBS - (int64_t)a) {
- limbs[a + i] = m_limbs[i];
- i += 1;
- }
- } else {
- limbs[a] = m_limbs[0] >> b;
- uint64_t i = a + 1;
- while (i < NUM_LIMBS) {
- limbs[i] = (m_limbs[i - a - 1] << (32 - b)) | (m_limbs[i - a] >> b);
- i += 1;
- }
- }
-
- return UnsignedInteger {limbs};
- }
-
- constexpr bool operator>(const UnsignedInteger rhs) const {
- for (uint64_t i = 0; i < NUM_LIMBS; i++) {
- if (m_limbs[i] > rhs.m_limbs[i]) return true;
- if (m_limbs[i] < rhs.m_limbs[i]) return false;
- }
-
- return false;
- }
-
- constexpr bool operator>=(const UnsignedInteger rhs) {
- for (uint64_t i = 0; i < NUM_LIMBS; i++) {
- if (m_limbs[i] > rhs.m_limbs[i]) return true;
- if (m_limbs[i] < rhs.m_limbs[i]) return false;
- }
-
- return true;
- }
-
- constexpr bool operator<(const UnsignedInteger rhs) const {
- for (uint64_t i = 0; i < NUM_LIMBS; i++) {
- if (m_limbs[i] > rhs.m_limbs[i]) return false;
- if (m_limbs[i] < rhs.m_limbs[i]) return true;
- }
-
- return false;
- }
-
- constexpr bool operator<=(const UnsignedInteger rhs) const {
- for (uint64_t i = 0; i < NUM_LIMBS; i++) {
- if (m_limbs[i] > rhs.m_limbs[i]) return false;
- if (m_limbs[i] < rhs.m_limbs[i]) return true;
- }
-
- return true;
- }
-};
-
-#endif /* unsigned_int_h */
diff --git a/mopro-msm/src/msm/metal/shader/curves/bn254.h.metal b/mopro-msm/src/msm/metal/shader/curves/bn254.h.metal
deleted file mode 100644
index 3279aa33..00000000
--- a/mopro-msm/src/msm/metal/shader/curves/bn254.h.metal
+++ /dev/null
@@ -1,52 +0,0 @@
-#pragma once
-
-#include "ec_point.h.metal"
-#include "../fields/fp_bn254.h.metal"
-#include "../tests/test_bn254.h.metal"
-
-namespace {
- typedef ECPoint BN254;
- typedef UnsignedInteger<8> u256;
-}
-
-template [[ host_name("bn254_add") ]]
-[[kernel]] void bn254_add(
- constant FpBN254*,
- constant FpBN254*,
- device FpBN254*
-);
-
-template [[ host_name("fp_bn254_add") ]]
-[[kernel]] void fp_bn254_add(
- constant FpBN254&,
- constant FpBN254&,
- device FpBN254&
-);
-
-template [[ host_name("fp_bn254_sub") ]]
-[[kernel]] void fp_bn254_sub(
- constant FpBN254&,
- constant FpBN254&,
- device FpBN254&
-);
-
-template [[ host_name("fp_bn254_mul") ]]
-[[kernel]] void fp_bn254_mul(
- constant FpBN254&,
- constant FpBN254&,
- device FpBN254&
-);
-
-template [[ host_name("fp_bn254_pow") ]]
-[[kernel]] void fp_bn254_pow(
- constant FpBN254&,
- constant uint32_t&,
- device FpBN254&
-);
-
-template [[ host_name("fp_bn254_neg") ]]
-[[kernel]] void fp_bn254_neg(
- constant FpBN254&,
- constant uint32_t&,
- device FpBN254&
-);
diff --git a/mopro-msm/src/msm/metal/shader/curves/ec_point.h.metal b/mopro-msm/src/msm/metal/shader/curves/ec_point.h.metal
deleted file mode 100644
index db961d96..00000000
--- a/mopro-msm/src/msm/metal/shader/curves/ec_point.h.metal
+++ /dev/null
@@ -1,145 +0,0 @@
-#pragma once
-
-template
-class ECPoint {
-public:
- Fp x;
- Fp y;
- Fp z;
-
- constexpr ECPoint() : ECPoint(ECPoint::neutral_element()) {}
- constexpr ECPoint(Fp _x, Fp _y, Fp _z) : x(_x), y(_y), z(_z) {}
-
- constexpr ECPoint operator+(const ECPoint other) const {
- if (is_neutral_element(*this)) {
- return other;
- }
- if (is_neutral_element(other)) {
- return *this;
- }
-
- // Z1Z1 = Z1^2
- Fp z1z1 = z * z;
-
- // Z2Z2 = Z2^2
- Fp z2z2 = other.z * other.z;
-
- // U1 = X1 * Z2Z2
- Fp u1 = x * z2z2;
-
- // U2 = X2 * Z1Z1
- Fp u2 = other.x * z1z1;
-
- // S1 = Y1 * Z2 * Z2Z2
- Fp s1 = y * other.z * z2z2;
-
- // S2 = Y2 * Z1 * Z1Z1
- Fp s2 = other.y * z * z1z1;
-
- if (u1 == u2 && s1 == s2) {
- // The points are equal, so we double
- return double_in_place();
- }
-
- // H = U2 - U1
- Fp h = u2 - u1;
-
- // I = (2 * H)^2
- Fp i = (h + h) * (h + h);
-
- // J = H * I
- Fp j = h * i;
-
- // r = 2 * (S2 - S1)
- Fp r = (s2 - s1) + (s2 - s1);
-
- // V = U1 * I
- Fp v = u1 * i;
-
- // X3 = r^2 - J - 2 * V
- Fp x3 = r * r - j - (v + v);
-
- // Y3 = r * (V - X3) - 2 * S1 * J
- Fp y3 = r * (v - x3) - (s1 + s1) * j;
-
- // Z3 = (Z1 + Z2)^2 - Z1Z1 - Z2Z2) * H
- Fp z3 = ((z + other.z) * (z + other.z) - z1z1 - z2z2) * h;
-
- return ECPoint(x3, y3, z3);
- }
-
- void operator+=(const ECPoint other) {
- *this = *this + other;
- }
-
- static ECPoint neutral_element() {
- return ECPoint(Fp(1), Fp(1), Fp(0)); // Updated to new neutral element (1, 1, 0)
- }
-
- ECPoint operate_with_self(uint64_t exponent) const {
- ECPoint result = neutral_element();
- ECPoint base = ECPoint(x, y, z);
-
- while (exponent > 0) {
- if ((exponent & 1) == 1) {
- result = result + base;
- }
- exponent = exponent >> 1;
- base = base + base;
- }
-
- return result;
- }
-
- constexpr ECPoint operator*(uint64_t exponent) const {
- return operate_with_self(exponent);
- }
-
- constexpr void operator*=(uint64_t exponent) {
- *this = operate_with_self(exponent);
- }
-
- constexpr ECPoint neg() const {
- return ECPoint(x, y.neg(), z);
- }
-
- constexpr bool is_neutral_element(const ECPoint a_point) const {
- return a_point.z == Fp(0); // Updated to check for (1, 1, 0)
- }
-
- constexpr ECPoint double_in_place() const {
- if (is_neutral_element(*this)) {
- return *this;
- }
-
- // Doubling formulas
- Fp a_fp = Fp(A_CURVE).to_montgomery();
- Fp two = Fp(2).to_montgomery();
- Fp three = Fp(3).to_montgomery();
-
- Fp eight = Fp(8).to_montgomery();
-
- Fp xx = x * x; // x^2
- Fp yy = y * y; // y^2
- Fp yyyy = yy * yy; // y^4
- Fp zz = z * z; // z^2
-
- // S = 2 * ((X1 + YY)^2 - XX - YYYY)
- Fp s = two * (((x + yy) * (x + yy)) - xx - yyyy);
-
- // M = 3 * XX + a * ZZ ^ 2
- Fp m = (three * xx) + (a_fp * (zz * zz));
-
- // X3 = T = M^2 - 2*S
- Fp x3 = (m * m) - (two * s);
-
- // Z3 = (Y + Z) ^ 2 - YY - ZZ
- // or Z3 = 2 * Y * Z
- Fp z3 = two * y * z;
-
- // Y3 = M*(S-X3)-8*YYYY
- Fp y3 = m * (s - x3) - eight * yyyy;
-
- return ECPoint(x3, y3, z3);
- }
-};
diff --git a/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal b/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal
deleted file mode 100644
index 45219c2c..00000000
--- a/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal
+++ /dev/null
@@ -1,291 +0,0 @@
-#pragma once
-
-#include "../arithmetics/unsigned_int.h.metal"
-
-
-namespace {
- // 8 limbs of 32 bits uint
- typedef UnsignedInteger<8> u256;
-}
-
-/* Constants for bn254 field operations
- * N: scalar field modulus
- * R_SQUARED: R^2 mod N
- * R_SUB_N: R - N
- * MU: Montgomery Multiplication Constant = -N^{-1} mod (2^32)
- *
- * For bn254, the modulus is "21888242871839275222246405745257275088548364400416034343698204186575808495617" [1, 2]
- * We use 8 limbs of 32 bits unsigned integers to represent the constanst
- *
- * References:
- * [1] https://github.com/arkworks-rs/algebra/blob/065cd24fc5ae17e024c892cee126ad3bd885f01c/curves/bn254/src/lib.rs
- * [2] https://github.com/scipr-lab/libff/blob/develop/libff/algebra/curves/alt_bn128/alt_bn128.sage
- */
-
-constexpr static const constant u256 N = {
- 0x30644E72, 0xE131A029,
- 0xB85045B6, 0x8181585D,
- 0x97816A91, 0x6871CA8D,
- 0x3C208C16, 0xD87CFD47
-};
-
-constexpr static const constant u256 R_SQUARED = {
- 0x06D89F71, 0xCAB8351F,
- 0x47AB1EFF, 0x0A417FF6,
- 0xB5E71911, 0xD44501FB,
- 0xF32CFC5B, 0x538AFA89
-};
-
-constexpr static const constant u256 R_SUB_N = {
- 0xCF9BB18D, 0x1ECE5FD6,
- 0x47AFBA49, 0x7E7EA7A2,
- 0x687E956E, 0x978E3572,
- 0xC3DF73E9, 0x278302B9
-};
-
-constexpr static const constant uint64_t MU = 3834012553;
-
-class FpBN254 {
-public:
- u256 inner;
- constexpr FpBN254() = default;
- constexpr FpBN254(uint64_t v) : inner{u256::from_int(v)} {}
- constexpr FpBN254(u256 v) : inner{v} {}
-
- constexpr explicit operator u256() const {
- return inner;
- }
-
- constexpr FpBN254 operator+(const FpBN254 rhs) const {
- return FpBN254(add(inner, rhs.inner));
- }
-
- constexpr FpBN254 operator-(const FpBN254 rhs) const {
- return FpBN254(sub(inner, rhs.inner));
- }
-
- constexpr FpBN254 operator*(const FpBN254 rhs) const {
- return FpBN254(mul(inner, rhs.inner));
- }
-
- constexpr bool operator==(const FpBN254 rhs) const {
- return inner == rhs.inner;
- }
-
- constexpr bool operator!=(const FpBN254 rhs) const {
- return !(inner == rhs.inner);
- }
-
- constexpr explicit operator uint32_t() const {
- return inner.m_limbs[7];
- }
-
- FpBN254 operator>>(const uint32_t rhs) const {
- return FpBN254(inner >> rhs);
- }
-
- FpBN254 operator<<(const uint32_t rhs) const {
- return FpBN254(inner << rhs);
- }
-
- constexpr static FpBN254 one() {
- const FpBN254 ONE = FpBN254::mul(u256::from_int((uint32_t) 1), R_SQUARED);
- return ONE;
- }
-
- constexpr FpBN254 to_montgomery() {
- return mul(inner, R_SQUARED);
- }
-
- FpBN254 pow(uint32_t exp) const {
- FpBN254 const ONE = one();
- FpBN254 res = ONE;
- FpBN254 power = *this;
-
- while (exp > 0) {
- if (exp & 1) {
- res = res * power;
- }
- exp >>= 1;
- power = power * power;
- }
-
- return res;
- }
-
- FpBN254 inverse() {
- // Generate by the command: addchain search '21888242871839275222246405745257275088696311157297823662689037894645226208583 - 2'
- // https://github.com/mmcloughlin/addchain
-
- // addchain: expr: "21888242871839275222246405745257275088696311157297823662689037894645226208583 - 2"
- // addchain: hex: 30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd45
- // addchain: dec: 21888242871839275222246405745257275088696311157297823662689037894645226208581
- // addchain: best: opt(dictionary(sliding_window(8),heuristic(use_first(halving,delta_largest))))
- // addchain: cost: 303
- // _10 = 2*1
- // _11 = 1 + _10
- // _101 = _10 + _11
- // _110 = 1 + _101
- // _1000 = _10 + _110
- // _1101 = _101 + _1000
- // _10010 = _101 + _1101
- // _10011 = 1 + _10010
- // _10100 = 1 + _10011
- // _10111 = _11 + _10100
- // _11100 = _101 + _10111
- // _100000 = _1101 + _10011
- // _100011 = _11 + _100000
- // _101011 = _1000 + _100011
- // _101111 = _10011 + _11100
- // _1000001 = _10010 + _101111
- // _1010011 = _10010 + _1000001
- // _1011011 = _1000 + _1010011
- // _1100001 = _110 + _1011011
- // _1110101 = _10100 + _1100001
- // _10010001 = _11100 + _1110101
- // _10010101 = _100000 + _1110101
- // _10110101 = _100000 + _10010101
- // _10111011 = _110 + _10110101
- // _11000001 = _110 + _10111011
- // _11000011 = _10 + _11000001
- // _11010011 = _10010 + _11000001
- // _11100001 = _100000 + _11000001
- // _11100011 = _10 + _11100001
- // _11100111 = _110 + _11100001
- // i57 = ((_11000001 << 8 + _10010001) << 10 + _11100111) << 7
- // i76 = ((_10111 + i57) << 9 + _10011) << 7 + _1101
- // i109 = ((i76 << 14 + _1010011) << 9 + _11100001) << 8
- // i127 = ((_1000001 + i109) << 10 + _1011011) << 5 + _1101
- // i161 = ((i127 << 8 + _11) << 12 + _101011) << 12
- // i186 = ((_10111011 + i161) << 8 + _101111) << 14 + _10110101
- // i214 = ((i186 << 9 + _10010001) << 5 + _1101) << 12
- // i236 = ((_11100011 + i214) << 8 + _10010101) << 11 + _11010011
- // i268 = ((i236 << 7 + _1100001) << 11 + _100011) << 12
- // i288 = ((_1011011 + i268) << 9 + _11000011) << 8 + _11100111
- // return (i288 << 7 + _1110101) << 6 + _101
-
- u256 _10 = mul(inner, inner);
- u256 _11 = mul(_10, inner);
- u256 _101 = mul(_10, _11);
- u256 _110 = mul(inner, _101);
- u256 _1000 = mul(_10, _110);
- u256 _1101 = mul(_101, _1000);
- u256 _10010 = mul(_101, _1101);
- u256 _10011 = mul(inner, _10010);
- u256 _10100 = mul(inner, _10011);
- u256 _10111 = mul(_11, _10100);
- u256 _11100 = mul(_101, _10111);
- u256 _100000 = mul(_1101, _10011);
- u256 _100011 = mul(_11, _100000);
- u256 _101011 = mul(_1000, _100011);
- u256 _101111 = mul(_10011, _11100);
- u256 _1000001 = mul(_10010, _101111);
- u256 _1010011 = mul(_10010, _1000001);
- u256 _1011011 = mul(_1000, _1010011);
- u256 _1100001 = mul(_110, _1011011);
- u256 _1110101 = mul(_10100, _1100001);
- u256 _10010001 = mul(_11100, _1110101);
- u256 _10010101 = mul(_100000, _1110101);
- u256 _10110101 = mul(_100000, _10010101);
- u256 _10111011 = mul(_110, _10110101);
- u256 _11000001 = mul(_110, _10111011);
- u256 _11000011 = mul(_10, _11000001);
- u256 _11010011 = mul(_10010, _11000001);
- u256 _11100001 = mul(_100000, _11000001);
- u256 _11100011 = mul(_10, _11100001);
- u256 _11100111 = mul(_110, _11100001);
- u256 i57 = sqn<7>(mul(sqn<10>(mul(sqn<8>(_11000001),_10010001)),_11100111));
- u256 i76 = mul(sqn<7>(mul(sqn<9>(mul(_10111,i57)),_10011)), _10011);
- u256 i109 = sqn<8>(mul(sqn<9>(mul(sqn<14>(i76),_1010011)),_11100001));
- u256 i127 = mul(sqn<5>(mul(sqn<10>(mul(_1000001,i109)),_1011011)),_1101);
- u256 i161 = sqn<12>(mul(sqn<12>(mul(sqn<8>(i127),_11)),_101011));
- u256 i186 = mul(sqn<14>(mul(sqn<8>(mul(_10111011,i161)),_101111)),_10110101);
- u256 i214 = sqn<12>(mul(sqn<5>(mul(sqn<9>(i186),_10010001)),_1101));
- u256 i236 = mul(sqn<11>(mul(sqn<8>(mul(_11100011,i214)),_10010101)),_11010011);
- u256 i268 = sqn<12>(mul(sqn<11>(mul(sqn<7>(i236),_1100001)),_100011));
- u256 i288 = mul(sqn<8>(mul(sqn<9>(mul(_1011011,i268)),_11000011)),_11100111);
- return FpBN254(mul(sqn<6>(mul(sqn<7>(i288),_1110101)),_101));
- }
-
- FpBN254 neg() {
- return FpBN254(sub(u256::from_int((uint32_t)0), inner));
- }
-
-private:
- template
- u256 sqn(u256 base) const {
- u256 result = base;
-#pragma unroll
- for (uint32_t i = 0; i < N_ACC; i++) {
- result = mul(result, result);
- }
- return result;
- }
-
- inline u256 add(const u256 lhs, const u256 rhs) const {
- u256 addition = lhs + rhs;
- u256 res = addition;
-
- return res - u256::from_int((uint64_t)(addition >= N)) * N + u256::from_int((uint64_t)(addition < lhs)) * R_SUB_N;
- }
-
- inline u256 sub(const u256 lhs, const u256 rhs) const {
- return add(lhs, ((u256)N) - rhs);
- }
-
- // Compute multiplication by performing single round of Montgomery reduction
- constexpr static u256 mul(const u256 a, const u256 b) {
- constexpr uint64_t NUM_LIMBS = 8;
- metal::array t = {};
- metal::array t_extra = {};
-
- u256 q = N;
-
- uint64_t i = NUM_LIMBS;
-
- while (i > 0) {
- i -= 1;
- uint64_t c = 0;
-
- uint64_t cs = 0;
- uint64_t j = NUM_LIMBS;
- while (j > 0) {
- j -= 1;
- cs = (uint64_t)t[j] + (uint64_t)a.m_limbs[j] * (uint64_t)b.m_limbs[i] + c;
- c = cs >> 32;
- t[j] = (uint32_t)((cs << 32) >> 32);
- }
-
- cs = (uint64_t)t_extra[1] + c;
- t_extra[0] = (uint32_t)(cs >> 32);
- t_extra[1] = (uint32_t)((cs << 32) >> 32);
-
- uint64_t m = (((uint64_t)t[NUM_LIMBS - 1] * MU) << 32) >> 32;
-
- c = ((uint64_t)t[NUM_LIMBS - 1] + m * (uint64_t)q.m_limbs[NUM_LIMBS - 1]) >> 32;
-
- j = NUM_LIMBS - 1;
- while (j > 0) {
- j -= 1;
- cs = (uint64_t)t[j] + m * (uint64_t)q.m_limbs[j] + c;
- c = cs >> 32;
- t[j + 1] = (uint32_t)((cs << 32) >> 32);
- }
-
- cs = (uint64_t)t_extra[1] + c;
- c = cs >> 32;
- t[0] = (uint32_t)((cs << 32) >> 32);
-
- t_extra[1] = t_extra[0] + (uint32_t)c;
- }
-
- u256 result {t};
-
- uint64_t overflow = t_extra[0] > 0;
- if (overflow || q <= result) {
- result = result - q;
- }
-
- return result;
- }
-};
diff --git a/mopro-msm/src/msm/metal/shader/fields/fp_u256.h.metal b/mopro-msm/src/msm/metal/shader/fields/fp_u256.h.metal
deleted file mode 100644
index 2bad59ea..00000000
--- a/mopro-msm/src/msm/metal/shader/fields/fp_u256.h.metal
+++ /dev/null
@@ -1,210 +0,0 @@
-#pragma once
-// https://github.com/andrewmilson/ministark/blob/main/gpu-poly/src/metal/felt_u256.h.metal
-
-#include "../arithmetics/u256.h.metal"
-
-template <
- /* =N **/ unsigned long N_0, unsigned long N_1, unsigned long N_2, unsigned long N_3,
- /* =R_SQUARED **/ unsigned long R_SQUARED_0, unsigned long R_SQUARED_1, unsigned long R_SQUARED_2, unsigned long R_SQUARED_3,
- /* =N_PRIME **/ unsigned long N_PRIME_0, unsigned long N_PRIME_1, unsigned long N_PRIME_2, unsigned long N_PRIME_3>
-class Fp256 {
-public:
- Fp256() = default;
- constexpr Fp256(unsigned long v) : inner(v) {}
- constexpr Fp256(u256 v) : inner(v) {}
-
- constexpr explicit operator u256() const
- {
- return inner;
- }
-
- constexpr Fp256 operator+(const Fp256 rhs) const
- {
- return Fp256(add(inner, rhs.inner));
- }
-
- constexpr Fp256 operator-(const Fp256 rhs) const
- {
- return Fp256(sub(inner, rhs.inner));
- }
-
- Fp256 operator*(const Fp256 rhs) const
- {
- return Fp256(mul(inner, rhs.inner));
- }
-
- bool operator==(const Fp256 rhs) const
- {
- return inner == rhs.inner;
- }
-
- bool operator!=(const Fp256 rhs) const
- {
- return inner != rhs.inner;
- }
-
- // TODO: make method for all fields
- constexpr Fp256 pow(unsigned exp) const
- {
- // TODO find a way to generate on compile time
- Fp256 const ONE = mul(u256(1), R_SQUARED);
- Fp256 res = ONE;
- Fp256 initial = *this;
-
- while (exp > 0)
- {
- if (exp & 1)
- {
- res = res * initial;
- }
- exp >>= 1;
- initial = initial * initial;
- }
-
- return res;
- }
-
- constexpr Fp256 inverse()
- {
- // used addchain
- // https://github.com/mmcloughlin/addchain
- u256 _10 = mul(inner, inner);
- u256 _11 = mul(_10, inner);
- u256 _1100 = sqn<2>(_11);
- u256 _1101 = mul(inner, _1100);
- u256 _1111 = mul(_10, _1101);
- u256 _11001 = mul(_1100, _1101);
- u256 _110010 = mul(_11001, _11001);
- u256 _110011 = mul(inner, _110010);
- u256 _1000010 = mul(_1111, _110011);
- u256 _1001110 = mul(_1100, _1000010);
- u256 _10000001 = mul(_110011, _1001110);
- u256 _11001111 = mul(_1001110, _10000001);
- u256 i14 = mul(_11001111, _11001111);
- u256 i15 = mul(_10000001, i14);
- u256 i16 = mul(i14, i15);
- u256 x10 = mul(_1000010, i16);
- u256 i27 = sqn<10>(x10);
- u256 i28 = mul(i16, i27);
- u256 i38 = sqn<10>(i27);
- u256 i39 = mul(i28, i38);
- u256 i49 = sqn<10>(i38);
- u256 i50 = mul(i39, i49);
- u256 i60 = sqn<10>(i49);
- u256 i61 = mul(i50, i60);
- u256 i72 = mul(sqn<10>(i60), i61);
- u256 x60 = mul(_1000010, i72);
- u256 i76 = sqn<2>(mul(i72, x60));
- u256 x64 = mul(mul(i15, i76), i76);
- u256 i208 = mul(sqn<64>(mul(sqn<63>(mul(i15, x64)), x64)), x64);
- return Fp256(mul(sqn<60>(i208), x60));
- }
-
- constexpr Fp256 neg()
- {
- // TODO: can improve
- return Fp256(sub(0, inner));
- }
-
-private:
- u256 inner;
-
- constexpr static const constant u256 N = u256(N_0, N_1, N_2, N_3);
- constexpr static const constant u256 R_SQUARED = u256(R_SQUARED_0, R_SQUARED_1, R_SQUARED_2, R_SQUARED_3);
- constexpr static const constant u256 N_PRIME = u256(N_PRIME_0, N_PRIME_1, N_PRIME_2, N_PRIME_3);
-
- // Equates to `(1 << 256) - N`
- constexpr static const constant u256 R_SUB_N =
- u256(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) - N + u256(1);
-
- template
- u256 sqn(u256 base) const {
- u256 result = base;
-#pragma unroll
- for (unsigned i = 0; i < N_ACC; i++) {
- result = mul(result, result);
- }
- return result;
- }
-
- // Computes `lhs + rhs mod N`
- // Returns value in range [0,N)
- inline u256 add(const u256 lhs, const u256 rhs) const
- {
- u256 addition = (lhs + rhs);
- u256 res = addition;
- // TODO: determine if an if statement here are more optimal
- return res - u256(addition >= N) * N + u256(addition < lhs) * R_SUB_N;
- }
-
- // Computes `lhs - rhs mod N`
- // Assumes `rhs` value in range [0,N)
- inline u256 sub(const u256 lhs, const u256 rhs) const
- {
- // TODO: figure what goes on here with "constant" scope variables
- return add(lhs, ((u256)N) - rhs);
- }
-
- // Computes `lhs * rhs mod M`
- //
- // Essential that inputs are already in the range [0,N) and are in montgomery
- // form. Multiplication performs single round of montgomery reduction.
- //
- // Reference:
- // - https://en.wikipedia.org/wiki/Montgomery_modular_multiplication (REDC)
- // - https://www.youtube.com/watch?v=2UmQDKcelBQ
- u256 mul(const u256 lhs, const u256 rhs) const
- {
- u256 lhs_low = lhs.low;
- u256 lhs_high = lhs.high;
- u256 rhs_low = rhs.low;
- u256 rhs_high = rhs.high;
-
- u256 partial_t_high = lhs_high * rhs_high;
- u256 partial_t_mid_a = lhs_high * rhs_low;
- u256 partial_t_mid_a_low = partial_t_mid_a.low;
- u256 partial_t_mid_a_high = partial_t_mid_a.high;
- u256 partial_t_mid_b = rhs_high * lhs_low;
- u256 partial_t_mid_b_low = partial_t_mid_b.low;
- u256 partial_t_mid_b_high = partial_t_mid_b.high;
- u256 partial_t_low = lhs_low * rhs_low;
-
- u256 tmp = partial_t_mid_a_low +
- partial_t_mid_b_low + partial_t_low.high;
- u256 carry = tmp.high;
- u256 t_low = u256(tmp.low, partial_t_low.low);
- u256 t_high = partial_t_high + partial_t_mid_a_high + partial_t_mid_b_high + carry;
-
- // Compute `m = T * N' mod R`
- u256 m = t_low * N_PRIME;
-
- // Compute `t = (T + m * N) / R`
- u256 n = N;
- u256 n_low = n.low;
- u256 n_high = n.high;
- u256 m_low = m.low;
- u256 m_high = m.high;
-
- u256 partial_mn_high = m_high * n_high;
- u256 partial_mn_mid_a = m_high * n_low;
- u256 partial_mn_mid_a_low = partial_mn_mid_a.low;
- u256 partial_mn_mid_a_high = partial_mn_mid_a.high;
- u256 partial_mn_mid_b = n_high * m_low;
- u256 partial_mn_mid_b_low = partial_mn_mid_b.low;
- u256 partial_mn_mid_b_high = partial_mn_mid_b.high;
- u256 partial_mn_low = m_low * n_low;
-
- tmp = partial_mn_mid_a_low + partial_mn_mid_b_low + u256(partial_mn_low.high);
- carry = tmp.high;
- u256 mn_low = u256(tmp.low, partial_mn_low.low);
- u256 mn_high = partial_mn_high + partial_mn_mid_a_high + partial_mn_mid_b_high + carry;
-
- u256 overflow = mn_low + t_low < mn_low;
- u256 t_tmp = t_high + overflow;
- u256 t = t_tmp + mn_high;
- u256 overflows_r = t < t_tmp;
- u256 overflows_modulus = t >= N;
-
- return t + overflows_r * R_SUB_N - overflows_modulus * N;
- }
-};
diff --git a/mopro-msm/src/msm/metal/shader/helper/bigint_to_hex.py b/mopro-msm/src/msm/metal/shader/helper/bigint_to_hex.py
deleted file mode 100644
index 24c5765a..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/bigint_to_hex.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Big integer to convert
-big_int = 21888242871839275222246405745257275088696311157297823662689037894645226208583
-
-# Step 1: Convert the big integer to a hexadecimal string
-hex_str = hex(big_int)[2:] # Removing the '0x' prefix
-
-# Step 2: Ensure the hex string length is a multiple of 8 by padding with leading zeros if necessary
-if len(hex_str) % 8 != 0:
- hex_str = hex_str.zfill((len(hex_str) // 8 + 1) * 8)
-
-# Step 3: Split the hex string into chunks of 8 characters (Big Endian order)
-limbs = [hex_str[i:i+8].upper() for i in range(0, len(hex_str), 8)]
-
-# Print the results in Big Endian order
-print("Decimal Integer:", big_int)
-print("Hexadecimal String:", hex_str)
-print("32-bit unsigned integer limbs in hex format (Big Endian):")
-
-# \n for every two limbs
-for i in range(0, len(limbs), 2):
- print("0x" + limbs[i] + ", 0x" + limbs[i+1] + ",")
diff --git a/mopro-msm/src/msm/metal/shader/helper/bn254_addchain.acc b/mopro-msm/src/msm/metal/shader/helper/bn254_addchain.acc
deleted file mode 100644
index bba8b662..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/bn254_addchain.acc
+++ /dev/null
@@ -1,41 +0,0 @@
-_10 = 2*1
-_11 = 1 + _10
-_101 = _10 + _11
-_110 = 1 + _101
-_1000 = _10 + _110
-_1101 = _101 + _1000
-_10010 = _101 + _1101
-_10011 = 1 + _10010
-_10100 = 1 + _10011
-_10111 = _11 + _10100
-_11100 = _101 + _10111
-_100000 = _1101 + _10011
-_100011 = _11 + _100000
-_101011 = _1000 + _100011
-_101111 = _10011 + _11100
-_1000001 = _10010 + _101111
-_1010011 = _10010 + _1000001
-_1011011 = _1000 + _1010011
-_1100001 = _110 + _1011011
-_1110101 = _10100 + _1100001
-_10010001 = _11100 + _1110101
-_10010101 = _100000 + _1110101
-_10110101 = _100000 + _10010101
-_10111011 = _110 + _10110101
-_11000001 = _110 + _10111011
-_11000011 = _10 + _11000001
-_11010011 = _10010 + _11000001
-_11100001 = _100000 + _11000001
-_11100011 = _10 + _11100001
-_11100111 = _110 + _11100001
-i57 = ((_11000001 << 8 + _10010001) << 10 + _11100111) << 7
-i76 = ((_10111 + i57) << 9 + _10011) << 7 + _1101
-i109 = ((i76 << 14 + _1010011) << 9 + _11100001) << 8
-i127 = ((_1000001 + i109) << 10 + _1011011) << 5 + _1101
-i161 = ((i127 << 8 + _11) << 12 + _101011) << 12
-i186 = ((_10111011 + i161) << 8 + _101111) << 14 + _10110101
-i214 = ((i186 << 9 + _10010001) << 5 + _1101) << 12
-i236 = ((_11100011 + i214) << 8 + _10010101) << 11 + _11010011
-i268 = ((i236 << 7 + _1100001) << 11 + _100011) << 12
-i288 = ((_1011011 + i268) << 9 + _11000011) << 8 + _11100111
-return (i288 << 7 + _1110101) << 6 + _101
diff --git a/mopro-msm/src/msm/metal/shader/helper/mu.py b/mopro-msm/src/msm/metal/shader/helper/mu.py
deleted file mode 100644
index a5c5e796..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/mu.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from sympy import mod_inverse
-
-N = 21888242871839275222246405745257275088696311157297823662689037894645226208583
-mod = 1 << 32
-
-# Step 1: N^-1 mod 2^32
-try:
- N_inv = mod_inverse(N, mod)
-except ValueError as e:
- print(e)
- N_inv = None
-
-# Step 2: Compute -N^-1 mod 2^32
-if N_inv is not None:
- result = mod - N_inv
-
-# Step 3: Make sure N_inv is computed correctly
-if N_inv is not None:
- # compute N * -N^-1 mod 2^32 = -1
- print("Check that N * N^-1 mod 2^32:", (N * result) % mod)
- assert (N * result) % mod == mod - 1
-
-# Step 4: Convert the result to a hexadecimal string
-hex_str = hex(result)[2:] # Removing the '0x' prefix
-
-# Step 5: Ensure the hex string length is a multiple of 8 by padding with leading zeros if necessary
-if len(hex_str) % 8 != 0:
- hex_str = hex_str.zfill((len(hex_str) // 8 + 1) * 8)
-
-# Step 6: Split the hex string into chunks of 8 characters (Big Endian order)
-limbs = [hex_str[i:i+8].upper() for i in range(0, len(hex_str), 8)]
-
-# Print the results in Big Endian order
-print(f"Result in Hexadecimal: 0x{hex_str.upper()}FFFFFFFF")
-print(f"Result in Decimal: {result}")
\ No newline at end of file
diff --git a/mopro-msm/src/msm/metal/shader/helper/r_sqr_mod_n.py b/mopro-msm/src/msm/metal/shader/helper/r_sqr_mod_n.py
deleted file mode 100644
index 9f526ede..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/r_sqr_mod_n.py
+++ /dev/null
@@ -1,23 +0,0 @@
-R = (1 << 256)
-N = 21888242871839275222246405745257275088696311157297823662689037894645226208583
-
-# Step 1: Compute R^2 mod N
-result = R**2 % N
-
-# Step 2: Convert the result to a hexadecimal string
-hex_str = hex(result)[2:] # Removing the '0x' prefix
-
-# Step 3: Ensure the hex string length is a multiple of 8 by padding with leading zeros if necessary
-if len(hex_str) % 8 != 0:
- hex_str = hex_str.zfill((len(hex_str) // 8 + 1) * 8)
-
-# Step 4: Split the hex string into chunks of 8 characters (Big Endian order)
-limbs = [hex_str[i:i+8].upper() for i in range(0, len(hex_str), 8)]
-
-# Print the results in Big Endian order
-print("Hexadecimal String:", hex_str)
-print("32-bit unsigned integer limbs in hex format (Big Endian):")
-
-# \n for every two limbs
-for i in range(0, len(limbs), 2):
- print("0x" + limbs[i] + ", 0x" + limbs[i+1] + ",")
diff --git a/mopro-msm/src/msm/metal/shader/helper/r_sub_n.py b/mopro-msm/src/msm/metal/shader/helper/r_sub_n.py
deleted file mode 100644
index 72ce00db..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/r_sub_n.py
+++ /dev/null
@@ -1,23 +0,0 @@
-R = (1 << 256)
-N = 21888242871839275222246405745257275088696311157297823662689037894645226208583
-
-# Step 1: Compute R - N
-result = R - N
-
-# Step 2: Convert the result to a hexadecimal string
-hex_str = hex(result)[2:] # Removing the '0x' prefix
-
-# Step 3: Ensure the hex string length is a multiple of 8 by padding with leading zeros if necessary
-if len(hex_str) % 8 != 0:
- hex_str = hex_str.zfill((len(hex_str) // 8 + 1) * 8)
-
-# Step 4: Split the hex string into chunks of 8 characters (Big Endian order)
-limbs = [hex_str[i:i+8].upper() for i in range(0, len(hex_str), 8)]
-
-# Print the results in Big Endian order
-print("Hexadecimal String:", hex_str)
-print("32-bit unsigned integer limbs in hex format (Big Endian):")
-
-# \n for every two limbs
-for i in range(0, len(limbs), 2):
- print("0x" + limbs[i] + ", 0x" + limbs[i+1] + ",")
diff --git a/mopro-msm/src/msm/metal/shader/helper/requirements.txt b/mopro-msm/src/msm/metal/shader/helper/requirements.txt
deleted file mode 100644
index ded0ee75..00000000
--- a/mopro-msm/src/msm/metal/shader/helper/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-sympy
diff --git a/mopro-msm/src/msm/metal/shader/msm.h.metal b/mopro-msm/src/msm/metal/shader/msm.h.metal
deleted file mode 100644
index 61c5b675..00000000
--- a/mopro-msm/src/msm/metal/shader/msm.h.metal
+++ /dev/null
@@ -1,252 +0,0 @@
-#pragma once
-
-#include "curves/bn254.h.metal"
-#include "fields/fp_bn254.h.metal"
-#include "arithmetics/unsigned_int.h.metal"
-
-namespace {
- typedef UnsignedInteger<8> u256;
- typedef FpBN254 FieldElement;
- typedef ECPoint Point;
-}
-
-constant constexpr uint32_t NUM_LIMBS = 8; // u256
-
-[[kernel]] void initialize_buckets(
- constant const uint32_t& _window_size [[ buffer(0) ]],
- constant const uint32_t* _window_starts [[ buffer(1) ]],
- device Point* buckets_matrix [[ buffer(2) ]],
- const uint32_t thread_id [[ thread_position_in_grid ]],
- const uint32_t total_threads [[ threads_per_grid ]]
-)
-{
- if (thread_id >= total_threads) {
- return;
- }
-
- uint32_t window_size = _window_size; // c in arkworks code
- uint32_t window_idx = _window_starts[thread_id];
- uint32_t buckets_len = (1 << window_size) - 1;
-
- for (uint32_t i = 0; i < buckets_len; i++) {
- buckets_matrix[window_idx + i] = Point::neutral_element();
- }
-}
-
-[[kernel]] void accumulation_and_reduction_phase(
- constant const uint32_t& _window_size [[ buffer(0) ]],
- constant const uint32_t& _instances_size [[ buffer(1) ]],
- constant const uint32_t* _window_starts [[ buffer(2) ]],
- constant const u256* k_buff [[ buffer(3) ]],
- constant const Point* p_buff [[ buffer(4) ]],
- device Point* buckets_matrix [[ buffer(5) ]],
- device Point* res [[ buffer(6) ]],
- const uint32_t thread_id [[ thread_position_in_grid ]],
- const uint32_t total_threads [[ threads_per_grid ]]
-)
-{
- if (thread_id >= total_threads) {
- return;
- }
-
- uint32_t window_size = _window_size; // c in arkworks code
- uint32_t instances_size = _instances_size;
- uint32_t buckets_len = (1 << window_size) - 1;
- uint32_t window_idx = _window_starts[thread_id];
-
- u256 one = u256::from_int((uint32_t)1);
- res[thread_id] = Point::neutral_element();
-
- // for each points and scalars, calculate the bucket index and accumulate
- for (uint32_t i = 0; i < instances_size; i++) {
- u256 k = k_buff[i];
- Point p = p_buff[i];
- // pass if k is one
- if (k == one) {
- if (window_idx == 0) {
- Point this_res = res[thread_id];
- res[thread_id] = this_res + p_buff[i];
- }
- }
- else {
- // move the b-bit scalar to the correct c-bit scalar fragment
- uint32_t scalar_fragment = (k >> window_idx).m_limbs[NUM_LIMBS - 1];
- // truncate the scalar fragment to the window size
- uint32_t m_ij = scalar_fragment & buckets_len;
-
- if (m_ij != 0) {
- uint32_t idx = m_ij - 1;
- Point bucket = buckets_matrix[thread_id * buckets_len + idx];
- buckets_matrix[thread_id * buckets_len + idx] = bucket + p;
- }
- }
- }
-
- // Perform sum reduction on buckets
- // Get the last bucket of this window
- uint32_t last_bucket_idx = (thread_id + 1) * buckets_len - 1;
-
- Point running_sum = Point::neutral_element();
- for (uint32_t i = 0; i < buckets_len; i++) {
- running_sum = running_sum + buckets_matrix[last_bucket_idx - i];
- Point this_res = res[thread_id];
- res[thread_id] = this_res + running_sum;
- }
-}
-
-// instance-wise parallel
-[[kernel]] void prepare_buckets_indices(
- constant const uint32_t& _window_size [[ buffer(0) ]],
- constant const uint32_t* _window_starts [[ buffer(1) ]],
- constant const uint32_t& _num_windows [[ buffer(2) ]],
- constant const u256* k_buff [[ buffer(3) ]],
- device uint2* buckets_indices [[ buffer(4) ]],
- const uint32_t thread_id [[ thread_position_in_grid ]],
- const uint32_t total_threads [[ threads_per_grid ]]
-)
-{
- if (thread_id >= total_threads) {
- return;
- }
-
- uint32_t window_size = _window_size; // c in arkworks code
- uint32_t num_windows = _num_windows;
- uint32_t buckets_len = (1 << window_size) - 1;
- u256 this_scalar = k_buff[thread_id];
-
- // skip if the scalar is uint scalar
- u256 one = u256::from_int((uint32_t)1);
- if (this_scalar == one) {
- return;
- }
-
- // for each window, record the corresponding bucket index and point idx
- for (uint32_t i = 0; i < num_windows; i++) {
- uint32_t window_idx = _window_starts[i];
-
- uint32_t scalar_fragment = (this_scalar >> window_idx).m_limbs[NUM_LIMBS - 1];
- uint32_t m_ij = scalar_fragment & buckets_len;
-
- // the case (b_idx, p_idx) = (0, 0) is not possible
- // since thread_id == 0 && i == 0 && m_ij == 1 is not possible
- if (m_ij != 0) {
- uint32_t bucket_idx = i * buckets_len + m_ij - 1;
- uint32_t point_idx = thread_id;
- buckets_indices[thread_id * num_windows + i] = uint2(bucket_idx, point_idx);
- }
- }
-}
-
-// TODO: sorting buckets_indices with bucket_idx as key
-
-[[kernel]] void bucket_wise_accumulation(
- constant const uint32_t& _instances_size [[ buffer(0) ]],
- constant const uint32_t& _num_windows [[ buffer(1) ]],
- constant const Point* p_buff [[ buffer(2) ]],
- device uint2* buckets_indices [[ buffer(3) ]],
- device Point* buckets_matrix [[ buffer(4) ]],
- constant const uint32_t& _max_thread_size [[ buffer(5) ]],
- uint2 dispatch_threads_per_threadgroup [[ dispatch_threads_per_threadgroup ]],
- uint2 threadgroup_position_in_grid [[ threadgroup_position_in_grid ]],
- uint tid [[ thread_index_in_threadgroup ]]
-)
-{
- uint max_threads_per_threadgroup = dispatch_threads_per_threadgroup.x * dispatch_threads_per_threadgroup.y;
- uint gid = threadgroup_position_in_grid.x;
- uint thread_id = gid * max_threads_per_threadgroup + tid;
-
- uint32_t max_thread_size = _max_thread_size;
- if (thread_id >= max_thread_size) {
- return;
- }
-
- uint32_t num_windows = _num_windows;
- uint32_t instances_size = _instances_size;
-
- uint32_t bucket_start_idx = 0;
- uint32_t max_idx = num_windows * instances_size;
-
- while (thread_id != buckets_indices[bucket_start_idx].x && bucket_start_idx < max_idx) {
- bucket_start_idx++;
- }
- // return if the bucket needs no accumulation
- if (bucket_start_idx == max_idx) {
- return;
- }
-
- // perform bucket-wise accumulation
- while (thread_id == buckets_indices[bucket_start_idx].x && bucket_start_idx < max_idx) {
- Point p = buckets_matrix[thread_id];
- buckets_matrix[thread_id] = p + p_buff[buckets_indices[bucket_start_idx].y];
- bucket_start_idx++;
- }
-}
-
-// window-wise reduction
-[[kernel]] void sum_reduction(
- constant const uint32_t& _window_size [[ buffer(0) ]],
- constant const u256* k_buff [[ buffer(1) ]],
- constant const Point* p_buff [[ buffer(2) ]],
- device Point* buckets_matrix [[ buffer(3) ]],
- device Point* res [[ buffer(4) ]],
- constant const uint32_t& _max_thread_size [[ buffer(5) ]],
- const uint32_t thread_id [[ thread_index_in_threadgroup ]]
-)
-{
- uint32_t max_thread_size = _max_thread_size;
- if (thread_id >= max_thread_size) {
- return;
- }
-
- uint32_t window_size = _window_size; // c in arkworks code
- uint32_t buckets_len = (1 << window_size) - 1;
-
- u256 one = u256::from_int((uint32_t)1);
- res[thread_id] = Point::neutral_element();
-
- // get the res for the first window
- if (thread_id == 0) {
- u256 k = k_buff[thread_id];
- if (k == one) {
- Point this_res = res[thread_id];
- res[thread_id] = this_res + p_buff[thread_id];
- }
- }
-
- // Perform sum reduction on buckets
- // Get the last bucket of this window
- uint32_t last_bucket_idx = (thread_id + 1) * buckets_len - 1;
-
- Point running_sum = Point::neutral_element();
- for (uint32_t i = 0; i < buckets_len; i++) {
- running_sum = running_sum + buckets_matrix[last_bucket_idx - i];
- Point this_res = res[thread_id];
- res[thread_id] = this_res + running_sum;
- }
-}
-
-
-[[kernel]] void final_accumulation(
- constant const uint32_t& _window_size [[ buffer(0) ]],
- constant const uint32_t* _window_starts [[ buffer(1) ]],
- constant const uint32_t& _num_windows [[ buffer(2) ]],
- device Point* res [[ buffer(3) ]],
- device Point& msm_result [[ buffer(4) ]]
-)
-{
- uint32_t window_size = _window_size; // c in arkworks code
- uint32_t num_windows = _num_windows;
- Point lowest_window_sum = res[0];
- uint32_t last_res_idx = num_windows - 1;
-
- Point total_sum = Point::neutral_element();
- for (uint32_t i = 1; i < num_windows; i++) {
- Point tmp = total_sum;
- total_sum = tmp + res[last_res_idx - i + 1];
-
- for (uint32_t j = 0; j < window_size; j++) {
- total_sum = total_sum.double_in_place();
- }
- }
- msm_result = total_sum + lowest_window_sum;
-}
diff --git a/mopro-msm/src/msm/metal/shader/tests/test_bn254.h.metal b/mopro-msm/src/msm/metal/shader/tests/test_bn254.h.metal
deleted file mode 100644
index 313c565e..00000000
--- a/mopro-msm/src/msm/metal/shader/tests/test_bn254.h.metal
+++ /dev/null
@@ -1,85 +0,0 @@
-#pragma once
-
-#include "../fields/fp_bn254.h.metal"
-
-using namespace metal;
-
-template
-[[kernel]] void bn254_add(
- constant Fp* p [[ buffer(0) ]],
- constant Fp* q [[ buffer(1) ]],
- device Fp* result [[ buffer(2) ]]
-)
-{
- BN254 P = BN254(p[0], p[1], p[2]);
- BN254 Q = BN254(q[0], q[1], q[2]);
- BN254 res = P + Q;
-
- result[0] = res.x;
- result[1] = res.y;
- result[2] = res.z;
-}
-
-template
-[[kernel]] void fp_bn254_add(
- constant FpBN254& _p [[ buffer(0) ]],
- constant FpBN254& _q [[ buffer(1) ]],
- device FpBN254& result [[ buffer(2) ]]
-) {
- FpBN254 p = _p;
- FpBN254 q = _q;
- result = p + q;
-}
-
-template
-[[kernel]] void fp_bn254_sub(
- constant FpBN254 &_p [[ buffer(0) ]],
- constant FpBN254 &_q [[ buffer(1) ]],
- device FpBN254 &result [[ buffer(2) ]]
-) {
- FpBN254 p = _p;
- FpBN254 q = _q;
- result = p - q;
-}
-
-template
-[[kernel]] void fp_bn254_mul(
- constant FpBN254 &_p [[ buffer(0) ]],
- constant FpBN254 &_q [[ buffer(1) ]],
- device FpBN254 &result [[ buffer(2) ]]
-) {
- FpBN254 p = _p;
- FpBN254 q = _q;
- result = p * q;
-}
-
-template
-[[kernel]] void fp_bn254_pow(
- constant FpBN254 &_p [[ buffer(0) ]],
- constant uint32_t &_a [[ buffer(1) ]],
- device FpBN254 &result [[ buffer(2) ]]
-) {
- FpBN254 p = _p;
- result = p.pow(_a);
-}
-
-template
-[[kernel]] void fp_bn254_neg(
- constant FpBN254 &_p [[ buffer(0) ]],
- constant uint32_t &_a [[ buffer(1) ]], // TODO: Remove this dummy arg
- device FpBN254 &result [[ buffer(2) ]]
-) {
- FpBN254 p = _p;
- result = p.neg();
-}
-
-// // TODO: Implement inverse if needed in the future
-// [[kernel]] void fp_bn254_inv(
-// constant FpBN254 &_p [[ buffer(0) ]],
-// constant FpBN254 &_q [[ buffer(1) ]],
-// device FpBN254 &result [[ buffer(2) ]]
-// ) {
-// FpBN254 p = _p;
-// FpBN254 inv_p = p.inverse();
-// result = inv_p;
-// }
diff --git a/mopro-msm/src/msm/metal/shader/tests/test_unsigned_integer.h.metal b/mopro-msm/src/msm/metal/shader/tests/test_unsigned_integer.h.metal
deleted file mode 100644
index 3b8f447b..00000000
--- a/mopro-msm/src/msm/metal/shader/tests/test_unsigned_integer.h.metal
+++ /dev/null
@@ -1,86 +0,0 @@
-#pragma once
-
-#include "../arithmetics/unsigned_int.h.metal"
-
-namespace {
- typedef UnsignedInteger<8> u256;
-}
-
-[[kernel]]
-void test_uint_add(
- constant u256& _a [[ buffer(0) ]],
- constant u256& _b [[ buffer(1) ]],
- device u256& result [[ buffer(2) ]])
-{
- u256 a = _a;
- u256 b = _b;
-
- result = a + b;
-}
-
-[[kernel]]
-void test_uint_sub(
- constant u256& _a [[ buffer(0) ]],
- constant u256& _b [[ buffer(1) ]],
- device u256& result [[ buffer(2) ]])
-{
- u256 a = _a;
- u256 b = _b;
-
- result = a - b;
-}
-
-[[kernel]]
-void test_uint_prod(
- constant u256& _a [[ buffer(0) ]],
- constant u256& _b [[ buffer(1) ]],
- device u256& result [[ buffer(2) ]])
-{
- u256 a = _a;
- u256 b = _b;
-
- result = a * b;
-}
-
-[[kernel]]
-void test_uint_shl(
- constant u256& _a [[ buffer(0) ]],
- constant uint32_t& _b [[ buffer(1) ]],
- device u256& result [[ buffer(2) ]])
-{
- u256 a = _a;
- uint32_t b = _b;
-
- result = a << b;
-}
-
-// [[kernel]]
-// void test_uint_shl(
-// constant u256& _a [[ buffer(0) ]],
-// constant uint32_t& _b [[ buffer(1) ]],
-// device u256& result [[ buffer(2) ]],
-// device uint32_t* debugBuffer [[ buffer(3) ]])
-// {
-// u256 a = _a;
-// uint32_t b = _b;
-
-// result = a << b;
-// // Write the values of a and b to the debug buffer
-// for (int i = 0; i < 8; ++i) {
-// debugBuffer[i] = a.m_limbs[i];
-// debugBuffer[i + 8] = b;
-// debugBuffer[i + 16] = result.m_limbs[i];
-// }
-// }
-
-[[kernel]]
-void test_uint_shr(
- constant u256& _a [[ buffer(0) ]],
- constant uint32_t& _b [[ buffer(1) ]],
- device u256& result [[ buffer(2) ]])
-{
- u256 a = _a;
- uint32_t b = _b;
-
- result = a >> b;
-}
diff --git a/mopro-msm/src/msm/metal/tests/mod.rs b/mopro-msm/src/msm/metal/tests/mod.rs
deleted file mode 100644
index adeb483e..00000000
--- a/mopro-msm/src/msm/metal/tests/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod test_bn254;
diff --git a/mopro-msm/src/msm/metal/tests/test_bn254.rs b/mopro-msm/src/msm/metal/tests/test_bn254.rs
deleted file mode 100644
index 2c42abca..00000000
--- a/mopro-msm/src/msm/metal/tests/test_bn254.rs
+++ /dev/null
@@ -1,425 +0,0 @@
-#[cfg(all(test))]
-mod tests {
- use crate::msm::metal::abstraction::{
- limbs_conversion::{FromLimbs, ToLimbs},
- state::MetalState,
- };
-
- use ark_bn254::{Fq, G1Projective as G};
- use ark_ff::{BigInt, Field};
- use ark_std::Zero;
-
- use metal::MTLSize;
- use proptest::prelude::*;
-
- pub type FE = Fq; // Field Element
-
- mod unsigned_int_tests {
- use super::*;
-
- enum BigOrSmallInt {
- Big(BigInteger256),
- Small(usize),
- }
-
- fn execute_kernel(name: &str, params: (BigInteger256, BigOrSmallInt)) -> BigInteger256 {
- let state = MetalState::new(None).unwrap();
- let pipeline = state.setup_pipeline(name).unwrap();
-
- let (a, b) = params;
-
- let a = a.to_u32_limbs();
-
- let result_buffer = state.alloc_buffer::(1);
-
- let debug_buffer = state.alloc_buffer::(24);
-
- let (command_buffer, command_encoder) = match b {
- BigOrSmallInt::Big(b) => {
- let b = b.to_u32_limbs();
- let a_buffer = state.alloc_buffer_data(&a);
- let b_buffer = state.alloc_buffer_data(&b);
- state.setup_command(
- &pipeline,
- Some(&[
- (0, &a_buffer),
- (1, &b_buffer),
- (2, &result_buffer),
- (3, &debug_buffer),
- ]),
- )
- }
- BigOrSmallInt::Small(b) => {
- let a_buffer = state.alloc_buffer_data(&a);
- let b_buffer = state.alloc_buffer_data(&[b]);
- state.setup_command(
- &pipeline,
- Some(&[
- (0, &a_buffer),
- (1, &b_buffer),
- (2, &result_buffer),
- (3, &debug_buffer),
- ]),
- )
- }
- };
-
- let threadgroup_size = MTLSize::new(1, 1, 1);
- let threadgroup_count = MTLSize::new(1, 1, 1);
-
- command_encoder.dispatch_thread_groups(threadgroup_count, threadgroup_size);
- command_encoder.end_encoding();
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- let limbs = MetalState::retrieve_contents::(&result_buffer);
-
- BigInteger256::from_u32_limbs(&limbs)
- }
-
- prop_compose! {
- fn rand_u128()(n in any::()) -> BigInteger256 { BigInteger256::from_u128(n) }
- }
- prop_compose! {
- fn rand_u32()(n in any::()) -> BigInteger256 { BigInteger256::from_u32(n) }
- }
-
- use ark_ff::biginteger::{BigInteger, BigInteger256};
- use num_bigint::BigUint;
-
- use BigOrSmallInt::{Big, Small};
-
- proptest! {
- #[test]
- fn add(a in rand_u128(), b in rand_u128()) {
- let mut result = BigInteger256::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("test_uint_add", (a, Big(b)));
- });
- let mut local_add = a;
- local_add.add_with_carry(&b);
- prop_assert_eq!(result, local_add);
- }
-
- #[test]
- fn sub(a in rand_u128(), b in rand_u128()) {
- let mut result = BigInteger256::default();
- let (a, b) = if a < b { (b, a) } else { (a, b) };
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("test_uint_sub", (a, Big(b)));
- });
- let mut local_sub = a;
- local_sub.sub_with_borrow(&b);
- prop_assert_eq!(result, local_sub);
- }
-
- #[test]
- fn prod(a in rand_u128(), b in rand_u32()) {
- let mut result = BigInteger256::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("test_uint_prod", (a, Big(b)));
- });
- let local_prod = BigUint::from(a) * BigUint::from(b);
- let mut base_bigint: [u64; 4] = [0; 4];
- for (i, limb) in local_prod.to_u64_digits().iter().enumerate() {
- base_bigint[i] = *limb;
- }
- let local_prod: BigInt<4> = BigInt(base_bigint);
- prop_assert_eq!(result, local_prod);
- }
-
- #[test]
- fn shl(a in rand_u128(), b in any::()) {
- let mut result = BigInteger256::default();
- let b = b % 256; // so it doesn't overflow
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("test_uint_shl", (a, Small(b)));
- });
- let mut local_shl = a;
- local_shl.muln(b as u32);
- prop_assert_eq!(result, local_shl);
- }
-
- #[test]
- fn shr(a in rand_u128(), b in any::()) {
- let mut result = BigInteger256::default();
- let b = b % 256; // so it doesn't overflow
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("test_uint_shr", (a, Small(b)));
- });
- let mut local_shr = a;
- local_shr.divn(b as u32);
- prop_assert_eq!(result, local_shr);
- }
- }
- }
-
- mod fp_tests {
- use super::*;
-
- use proptest::collection;
-
- enum FEOrInt {
- Elem(FE),
- Int(u32),
- }
-
- fn execute_kernel(name: &str, a: &FE, b: FEOrInt) -> FE {
- let state = MetalState::new(None).unwrap();
- let pipeline = state.setup_pipeline(name).unwrap();
-
- let a = a.to_u32_limbs();
- let result_buffer = state.alloc_buffer::(8);
-
- let (command_buffer, command_encoder) = match b {
- FEOrInt::Elem(b) => {
- let b = b.to_u32_limbs();
- let a_buffer = state.alloc_buffer_data(&a);
- let b_buffer = state.alloc_buffer_data(&b);
-
- state.setup_command(
- &pipeline,
- Some(&[(0, &a_buffer), (1, &b_buffer), (2, &result_buffer)]),
- )
- }
- FEOrInt::Int(b) => {
- let a_buffer = state.alloc_buffer_data(&a);
- let b_buffer = state.alloc_buffer_data(&[b]);
-
- state.setup_command(
- &pipeline,
- Some(&[(0, &a_buffer), (1, &b_buffer), (2, &result_buffer)]),
- )
- }
- };
-
- let threadgroup_size = MTLSize::new(1, 1, 1);
- let threadgroup_count = MTLSize::new(1, 1, 1);
-
- command_encoder.dispatch_thread_groups(threadgroup_count, threadgroup_size);
- command_encoder.end_encoding();
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- let limbs = MetalState::retrieve_contents::(&result_buffer);
- FE::from_u32_limbs(&limbs)
- }
-
- prop_compose! {
- fn rand_u32()(n in any::()) -> u32 { n }
- }
-
- prop_compose! {
- fn rand_limbs()(vec in collection::vec(rand_u32(), 8)) -> Vec {
- vec
- }
- }
-
- prop_compose! {
- fn rand_field_element()(limbs in rand_limbs()) -> FE {
- FE::from_u32_limbs(&limbs)
- }
- }
-
- use FEOrInt::{Elem, Int};
-
- proptest! {
- #[test]
- fn add(a in rand_field_element(), b in rand_field_element()) {
- let mut result = Fq::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("fp_bn254_add", &a, Elem(b.clone()));
- });
- let local_add = a + b;
- prop_assert_eq!(result, local_add);
- }
-
- #[test]
- fn sub(a in rand_field_element(), b in rand_field_element()) {
- let mut result = Fq::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("fp_bn254_sub", &a, Elem(b.clone()));
- });
- let local_sub = a - b;
- prop_assert_eq!(result, local_sub);
- }
-
- #[test]
- fn mul(a in rand_field_element(), b in rand_field_element()) {
- let mut result = Fq::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("fp_bn254_mul", &a, Elem(b.clone()));
- });
- let local_mul = a * b;
- prop_assert_eq!(result, local_mul);
- }
-
- #[test]
- fn neg(a in rand_field_element()) {
- let mut result = Fq::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("fp_bn254_neg", &a, Int(0));
- });
- let local_neg = -a;
- prop_assert_eq!(result, local_neg);
- }
-
- #[test]
- fn pow(a in rand_field_element(), b in rand_u32()) {
- let mut result = Fq::default();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("fp_bn254_pow", &a, Int(b));
- });
- let local_pow = a.pow(&[b as u64]);
- prop_assert_eq!(result, local_pow);
- }
-
- // // TODO: Implement inverse if needed in the future
- // #[test]
- // fn inv(a in rand_field_element()) {
- // let mut result = Fq::default();
- // objc::rc::autoreleasepool(|| {
- // result = execute_kernel("test_bn254_inv", &a, Int(0));
- // });
- // let local_inv = a.inverse().unwrap();
- // println!("a: {:?}", a.0);
- // println!("a inv: {:?}", local_inv.0);
- // println!("result: {:?}", result);
- // prop_assert_eq!(result.into_bigint(), local_inv.0);
- // }
- }
- }
-
- mod ec_tests {
- use ark_ff::UniformRand;
- use ark_std::rand::thread_rng;
-
- use super::*;
-
- fn point_to_u32_limbs(p: &G) -> Vec {
- p.x.to_u32_limbs()
- .into_iter()
- .chain(p.y.to_u32_limbs())
- .chain(p.z.to_u32_limbs())
- .collect()
- }
-
- fn execute_kernel(name: &str, p: &G, q: &G) -> Vec {
- let state = MetalState::new(None).unwrap();
- let pipeline = state.setup_pipeline(name).unwrap();
-
- let p_coordinates: Vec = point_to_u32_limbs(p);
- let q_coordinates: Vec = point_to_u32_limbs(q);
-
- let p_buffer = state.alloc_buffer_data(&p_coordinates);
- let q_buffer = state.alloc_buffer_data(&q_coordinates);
- let result_buffer = state.alloc_buffer::(24);
-
- let (command_buffer, command_encoder) = state.setup_command(
- &pipeline,
- Some(&[(0, &p_buffer), (1, &q_buffer), (2, &result_buffer)]),
- );
-
- let threadgroup_size = MTLSize::new(1, 1, 1);
- let threadgroup_count = MTLSize::new(1, 1, 1);
-
- command_encoder.dispatch_thread_groups(threadgroup_count, threadgroup_size);
- command_encoder.end_encoding();
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- MetalState::retrieve_contents::(&result_buffer)
- }
-
- prop_compose! {
- fn rand_u128()(n in any::()) -> u128 { n }
- }
-
- prop_compose! {
- fn rand_point()(_n in any::()) -> G {
- let rng = &mut thread_rng();
- G::rand(rng)
- }
- }
-
- proptest! {
- #[test]
- fn add(p in rand_point(), q in rand_point()) {
- let mut result = vec![];
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("bn254_add", &p, &q);
- });
- let gpu_result = G::new(
- Fq::from_u32_limbs(&result[0..8]),
- Fq::from_u32_limbs(&result[8..16]),
- Fq::from_u32_limbs(&result[16..24]),
- );
- let cpu_result = p + q;
- prop_assert_eq!(gpu_result, cpu_result);
- }
-
- #[test]
- fn add_with_self(p in rand_point()) {
- let mut result = vec![];
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("bn254_add", &p, &p);
- });
- let gpu_result = G::new(
- Fq::from_u32_limbs(&result[0..8]),
- Fq::from_u32_limbs(&result[8..16]),
- Fq::from_u32_limbs(&result[16..24]),
- );
- let cpu_result = p + p;
- prop_assert_eq!(gpu_result, cpu_result);
- }
-
- #[test]
- fn add_with_infinity_rhs(p in rand_point()) {
- let mut result = vec![];
- let infinity = G::zero();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("bn254_add", &p, &infinity);
- });
- let gpu_result = G::new(
- Fq::from_u32_limbs(&result[0..8]),
- Fq::from_u32_limbs(&result[8..16]),
- Fq::from_u32_limbs(&result[16..24]),
- );
- let cpu_result = p + infinity;
- prop_assert_eq!(gpu_result, cpu_result);
- }
-
- #[test]
- fn add_with_infinity_lhs(p in rand_point()) {
- let mut result = vec![];
- let infinity = G::zero();
- objc::rc::autoreleasepool(|| {
- result = execute_kernel("bn254_add", &infinity, &p);
- });
- let gpu_result = G::new(
- Fq::from_u32_limbs(&result[0..8]),
- Fq::from_u32_limbs(&result[8..16]),
- Fq::from_u32_limbs(&result[16..24]),
- );
- let cpu_result = infinity + p;
- prop_assert_eq!(gpu_result, cpu_result);
- }
- }
-
- #[test]
- fn infinity_plus_infinity_should_equal_infinity() {
- let infinity = G::zero();
- let result = execute_kernel("bn254_add", &infinity, &infinity);
- let gpu_result = G::new(
- Fq::from_u32_limbs(&result[0..8]),
- Fq::from_u32_limbs(&result[8..16]),
- Fq::from_u32_limbs(&result[16..24]),
- );
- let cpu_result = infinity + infinity;
- assert_eq!(gpu_result, cpu_result);
- }
- }
-}
diff --git a/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs b/mopro-msm/src/msm/metal_msm/host/metal_wrapper.rs
similarity index 100%
rename from mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs
rename to mopro-msm/src/msm/metal_msm/host/metal_wrapper.rs
diff --git a/mopro-msm/src/msm/metal_msm/host/mod.rs b/mopro-msm/src/msm/metal_msm/host/mod.rs
index 223e3615..ea4552c2 100644
--- a/mopro-msm/src/msm/metal_msm/host/mod.rs
+++ b/mopro-msm/src/msm/metal_msm/host/mod.rs
@@ -1,4 +1,5 @@
pub mod errors;
-// pub mod state;
pub mod gpu;
+pub mod metal_wrapper;
pub mod shader;
+pub mod shader_manager;
diff --git a/mopro-msm/src/msm/metal_msm/host/shader.rs b/mopro-msm/src/msm/metal_msm/host/shader.rs
index 988c2b97..d25308cf 100644
--- a/mopro-msm/src/msm/metal_msm/host/shader.rs
+++ b/mopro-msm/src/msm/metal_msm/host/shader.rs
@@ -47,11 +47,6 @@ pub fn write_constants(
let num_limbs_wide = num_limbs + 1;
let num_limbs_extra_wide = num_limbs * 2;
- // MSM instance params
- let chunk_size = 16;
- let num_columns = 2u32.pow(chunk_size);
- let num_subtasks = (256 as f32 / chunk_size as f32).ceil() as u32;
-
let basefield_modulus = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
let r = calc_mont_radix(num_limbs, log_limb_size);
let p: BigUint = BaseField::MODULUS.try_into().unwrap();
@@ -77,9 +72,6 @@ pub fn write_constants(
data += format!("#define N0 {}\n", n0).as_str();
data += format!("#define NSAFE {}\n", nsafe).as_str();
data += format!("#define SLACK {}\n", slack).as_str();
- data += format!("#define CHUNK_SIZE {}\n", chunk_size).as_str();
- data += format!("#define NUM_COLUMNS {}\n", num_columns).as_str();
- data += format!("#define NUM_SUBTASKS {}\n", num_subtasks).as_str();
let mu_limbs = mu_in_ark.to_limbs(num_limbs, log_limb_size);
write_constant_array!(data, "BARRETT_MU", mu_limbs, "NUM_LIMBS");
diff --git a/mopro-msm/src/msm/metal_msm/utils/shader_manager.rs b/mopro-msm/src/msm/metal_msm/host/shader_manager.rs
similarity index 99%
rename from mopro-msm/src/msm/metal_msm/utils/shader_manager.rs
rename to mopro-msm/src/msm/metal_msm/host/shader_manager.rs
index 6ee15c0a..b1c69a70 100644
--- a/mopro-msm/src/msm/metal_msm/utils/shader_manager.rs
+++ b/mopro-msm/src/msm/metal_msm/host/shader_manager.rs
@@ -1,6 +1,6 @@
use crate::msm::metal_msm::host::gpu::get_default_device;
// no runtime codegen of shader constants
-use crate::msm::metal_msm::utils::metal_wrapper::{
+use crate::msm::metal_msm::host::metal_wrapper::{
get_or_calc_constants, MSMConstants, MetalConfig,
};
use metal::*;
diff --git a/mopro-msm/src/msm/metal_msm/host/state.rs b/mopro-msm/src/msm/metal_msm/host/state.rs
deleted file mode 100644
index 8a7cdde9..00000000
--- a/mopro-msm/src/msm/metal_msm/host/state.rs
+++ /dev/null
@@ -1,118 +0,0 @@
-use metal::{ComputeCommandEncoderRef, MTLResourceOptions};
-
-use super::errors::MetalError;
-
-const LIB_DATA: &[u8] = include_bytes!("../shader/msm.metallib");
-
-/// Structure for abstracting basic calls to a Metal device and saving the state. Used for
-/// implementing GPU parallel computations in Apple machines.
-pub struct MetalState {
- pub device: metal::Device,
- pub library: metal::Library,
- pub queue: metal::CommandQueue,
-}
-
-impl MetalState {
- /// Creates a new Metal state with an optional `device` (GPU). If `None` is passed then it will use
- /// the system's default.
- pub fn new(device: Option) -> Result {
- let device: metal::Device =
- device.unwrap_or(metal::Device::system_default().ok_or(MetalError::DeviceNotFound())?);
-
- let library = device
- .new_library_with_data(LIB_DATA) // TODO: allow different files
- .map_err(MetalError::LibraryError)?;
- let queue = device.new_command_queue();
-
- Ok(Self {
- device,
- library,
- queue,
- })
- }
-
- /// Creates a pipeline based on a compute function `kernel` which needs to exist in the state's
- /// library. A pipeline is used for issuing commands to the GPU through command buffers,
- /// executing the `kernel` function.
- pub fn setup_pipeline(
- &self,
- kernel_name: &str,
- ) -> Result {
- let kernel = self
- .library
- .get_function(kernel_name, None)
- .map_err(MetalError::FunctionError)?;
-
- let pipeline = self
- .device
- .new_compute_pipeline_state_with_function(&kernel)
- .map_err(MetalError::PipelineError)?;
-
- Ok(pipeline)
- }
-
- /// Allocates `length` bytes of shared memory between CPU and the device (GPU).
- pub fn alloc_buffer(&self, length: usize) -> metal::Buffer {
- let size = mem::size_of::();
-
- self.device.new_buffer(
- (length * size) as u64,
- MTLResourceOptions::StorageModeShared, // TODO: use managed mode
- )
- }
-
- /// Allocates `data` in a buffer of shared memory between CPU and the device (GPU).
- pub fn alloc_buffer_data(&self, data: &[T]) -> metal::Buffer {
- let size = mem::size_of::();
-
- self.device.new_buffer_with_data(
- data.as_ptr() as *const ffi::c_void,
- (data.len() * size) as u64,
- MTLResourceOptions::StorageModeShared, // TODO: use managed mode
- )
- }
-
- pub fn set_bytes(index: usize, data: &[T], encoder: &ComputeCommandEncoderRef) {
- let size = mem::size_of::();
-
- encoder.set_bytes(
- index as u64,
- (data.len() * size) as u64,
- data.as_ptr() as *const ffi::c_void,
- );
- }
-
- /// Creates a command buffer and a compute encoder in a pipeline, optionally issuing `buffers`
- /// to it.
- pub fn setup_command(
- &self,
- pipeline: &metal::ComputePipelineState,
- buffers: Option<&[(u64, &metal::Buffer)]>,
- ) -> (&metal::CommandBufferRef, &metal::ComputeCommandEncoderRef) {
- let command_buffer = self.queue.new_command_buffer();
- let command_encoder = command_buffer.new_compute_command_encoder();
- command_encoder.set_compute_pipeline_state(pipeline);
-
- if let Some(buffers) = buffers {
- for (i, buffer) in buffers.iter() {
- command_encoder.set_buffer(*i, Some(buffer), 0);
- }
- }
-
- (command_buffer, command_encoder)
- }
-
- /// Returns a vector of a copy of the data that `buffer` holds, interpreting it into a specific
- /// type `T`.
- ///
- /// BEWARE: this function uses an unsafe function for retrieveing the data, if the buffer's
- /// contents don't match the specified `T`, expect undefined behaviour. Always make sure the
- /// buffer you are retreiving from holds data of type `T`.
- pub fn retrieve_contents(buffer: &metal::Buffer) -> Vec {
- let ptr = buffer.contents() as *const T;
- let len = buffer.length() as usize / mem::size_of::();
- let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
-
- slice.to_vec()
- }
-}
diff --git a/mopro-msm/src/msm/metal_msm/metal_msm.rs b/mopro-msm/src/msm/metal_msm/metal_msm.rs
index d9175e2d..875350d9 100644
--- a/mopro-msm/src/msm/metal_msm/metal_msm.rs
+++ b/mopro-msm/src/msm/metal_msm/metal_msm.rs
@@ -1,11 +1,10 @@
+use crate::msm::metal_msm::host::metal_wrapper::{MetalConfig, MetalHelper};
+use crate::msm::metal_msm::host::shader_manager::{ShaderManager, ShaderManagerConfig, ShaderType};
use crate::msm::metal_msm::utils::limbs_conversion::{
pack_affine_and_scalars, GenericLimbConversion,
};
-use crate::msm::metal_msm::utils::metal_wrapper::{MetalConfig, MetalHelper};
use crate::msm::metal_msm::utils::mont_reduction::raw_reduction;
-use crate::msm::metal_msm::utils::shader_manager::{
- ShaderManager, ShaderManagerConfig, ShaderType,
-};
+use crate::msm::metal_msm::utils::window_size_optimizer::fetch_gpu_core_count_and_simd_width_from_device;
use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as Affine, G1Projective as G};
use ark_ff::{BigInt, PrimeField};
use ark_std::{vec::Vec, Zero};
@@ -41,6 +40,8 @@ impl From for ShaderManagerConfig {
pub struct MetalMSMPipeline {
config: MetalMSMConfig,
shader_manager: ShaderManager,
+ gpu_cores: usize,
+ simd_width: usize,
}
impl MetalMSMPipeline {
@@ -48,9 +49,15 @@ impl MetalMSMPipeline {
let shader_config: ShaderManagerConfig = config.clone().into();
let shader_manager = ShaderManager::new(shader_config)?;
+ // Cache GPU core count once during initialization
+ let (gpu_cores, simd_width) =
+ fetch_gpu_core_count_and_simd_width_from_device(shader_manager.device());
+
Ok(Self {
config,
shader_manager,
+ gpu_cores,
+ simd_width,
})
}
@@ -58,11 +65,24 @@ impl MetalMSMPipeline {
Self::new(MetalMSMConfig::default())
}
+ /// Get the cached GPU core count
+ pub fn gpu_cores(&self) -> usize {
+ self.gpu_cores
+ }
+
/// Execute the complete MSM pipeline on GPU
- fn execute(&self, bases: &[Affine], scalars: &[ScalarField]) -> Result> {
- let input_size = bases.len();
- let num_subtasks = 256 / self.config.log_limb_size as usize;
- let num_columns = 1 << self.config.log_limb_size;
+ fn execute_pipeline(
+ &self,
+ bases: &[Affine],
+ scalars: &[ScalarField],
+ input_size: usize,
+ window_size: usize,
+ scale_factor: usize,
+ ) -> Result> {
+ // these params are set based on the Alogorithm 3 in the cuZK paper (https://eprint.iacr.org/2022/1321.pdf)
+ let num_columns = 1 << window_size;
+ let num_subtasks =
+ (ScalarField::MODULUS_BIT_SIZE as f32 / window_size as f32).ceil() as usize;
// Stage 0: Pack inputs
let metal_config = MetalConfig {
@@ -75,33 +95,18 @@ impl MetalMSMPipeline {
// Stage 1: Convert Point & Scalar Decomposition
let stage1 = ConvertPointAndScalarDecompose::new(&self.shader_manager);
- let mut c_workgroup_size = 256;
- let mut c_num_x_workgroups = 64;
- let mut c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
- let c_num_z_workgroups = 1;
- if input_size <= 256 {
- c_workgroup_size = input_size;
- c_num_x_workgroups = 1;
- c_num_y_workgroups = 1;
- } else if input_size > 256 && input_size <= 32768 {
- c_workgroup_size = 64;
- c_num_x_workgroups = 4;
- c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
- } else if input_size > 32768 && input_size <= 131072 {
- c_workgroup_size = 256;
- c_num_x_workgroups = 8;
- c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
- } else if input_size > 131072 && input_size < 1048576 {
- c_workgroup_size = 256;
- c_num_x_workgroups = 32;
- c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
- }
+ let c_workgroup_size = self.simd_width * scale_factor;
+ let c_num_x_workgroups = c_workgroup_size;
+ let c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
+ let c_num_z_workgroups = 1;
let (point_x, point_y, scalar_chunks) = stage1.execute(
&coords,
&scals,
input_size,
+ window_size,
+ num_columns,
num_subtasks,
c_num_x_workgroups,
c_num_y_workgroups,
@@ -110,12 +115,13 @@ impl MetalMSMPipeline {
)?;
// Stage 2: Transpose
+ let stage2 = Transpose::new(&self.shader_manager);
+
let t_num_x_workgroups = 1;
let t_num_y_workgroups = 1;
let t_num_z_workgroups = 1;
let t_workgroup_size = num_subtasks;
- let stage2 = Transpose::new(&self.shader_manager);
let (csc_col_ptr, csc_val_idxs) = stage2.execute(
&scalar_chunks,
num_subtasks,
@@ -128,12 +134,13 @@ impl MetalMSMPipeline {
)?;
// Stage 3: Sparse Matrix-Vector Multiplication
- let s_workgroup_size = 256;
- let s_num_x_workgroups = 64;
- let s_num_y_workgroups = 2;
- let s_num_z_workgroups = num_subtasks;
-
+ // according to the write-up (https://github.com/z-prize/2023-entries/tree/main/prize-2-msm-wasm/webgpu-only/tal-derei-koh-wei-jie#thread-count)
+ // the threads is (input size / 2) * num_subtasks
+ // after testing, 1D dim config is better than 3D dim config
let stage3 = SMVP::new(&self.shader_manager);
+
+ let s_workgroup_size = self.simd_width * scale_factor;
+
let (bucket_x, bucket_y, bucket_z) = stage3.execute(
&csc_col_ptr,
&csc_val_idxs,
@@ -142,26 +149,27 @@ impl MetalMSMPipeline {
input_size,
num_subtasks,
num_columns,
- s_num_x_workgroups,
- s_num_y_workgroups,
- s_num_z_workgroups,
+ self.simd_width,
s_workgroup_size,
)?;
// Stage 4: Parallel Bucket Reduction
- let num_subtasks_per_bpr_1 = 16;
- let num_subtasks_per_bpr_2 = 16;
+ // according to the write-up (https://github.com/z-prize/2023-entries/tree/main/prize-2-msm-wasm/webgpu-only/tal-derei-koh-wei-jie#thread-count)
+ // the threads is 2^k * num_subtasks
+ let stage4 = PBPR::new(&self.shader_manager);
+
+ let b_workgroup_size = self.simd_width * scale_factor;
+ let num_subtasks_per_bpr_1 = num_subtasks;
+ let num_subtasks_per_bpr_2 = num_subtasks;
let b_num_x_workgroups = num_subtasks_per_bpr_1;
let b_num_y_workgroups = 1;
let b_num_z_workgroups = 1;
- let b_workgroup_size = 256;
let b_2_num_x_workgroups = num_subtasks_per_bpr_2;
let b_2_num_y_workgroups = 1;
let b_2_num_z_workgroups = 1;
- let stage4 = PBPR::new(&self.shader_manager);
let (g_points_x, g_points_y, g_points_z) = stage4.execute(
&bucket_x,
&bucket_y,
@@ -179,27 +187,27 @@ impl MetalMSMPipeline {
b_workgroup_size,
)?;
- // Stage 5: Final reduction and Horner's method
+ // Stage 5 (on CPU): read buckets and perform final reduction with Horner's method
let result = self.final_reduction(
&g_points_x,
&g_points_y,
&g_points_z,
num_subtasks,
- self.config.log_limb_size as usize,
+ window_size,
b_workgroup_size,
)?;
Ok(result)
}
- /// Final reduction on CPU
+ /// Final reduction on CPU with Horner's method
fn final_reduction(
&self,
g_points_x: &[u32],
g_points_y: &[u32],
g_points_z: &[u32],
num_subtasks: usize,
- log_limb_size: usize,
+ window_size: usize,
pbpr_workgroup_size: usize,
) -> Result> {
// Parallel processing of subtasks
@@ -221,6 +229,7 @@ impl MetalMSMPipeline {
let yr_bigint = BigInt::<4>::from_limbs(yr_limbs, self.config.log_limb_size);
let zr_bigint = BigInt::<4>::from_limbs(zr_limbs, self.config.log_limb_size);
+ // decoded points from Montgomery form to standard form
let xr_reduced = raw_reduction(xr_bigint);
let yr_reduced = raw_reduction(yr_bigint);
let zr_reduced = raw_reduction(zr_bigint);
@@ -238,7 +247,7 @@ impl MetalMSMPipeline {
.collect();
// Horner's method
- let m = ScalarField::from(1u64 << log_limb_size);
+ let m = ScalarField::from(1u64 << window_size);
let mut result = gpu_points[gpu_points.len() - 1];
if gpu_points.len() > 1 {
@@ -267,6 +276,8 @@ impl<'a> ConvertPointAndScalarDecompose<'a> {
coords: &[u32],
scalars: &[u32],
input_size: usize,
+ window_size: usize,
+ num_columns: usize,
num_subtasks: usize,
c_num_x_workgroups: usize,
c_num_y_workgroups: usize,
@@ -288,8 +299,12 @@ impl<'a> ConvertPointAndScalarDecompose<'a> {
helper.create_empty_buffer(input_size * self.shader_manager.config().num_limbs);
let out_scalar_chunks = helper.create_empty_buffer(input_size * num_subtasks);
- let input_size_buf = helper.create_buffer(&vec![input_size as u32]);
- let num_y_workgroups_buf = helper.create_buffer(&vec![c_num_y_workgroups as u32]);
+ let params_buf = helper.create_buffer(&vec![
+ input_size as u32,
+ window_size as u32,
+ num_columns as u32,
+ num_subtasks as u32,
+ ]);
let thread_group_count = helper.create_thread_group_size(
c_num_x_workgroups as u64,
@@ -304,11 +319,10 @@ impl<'a> ConvertPointAndScalarDecompose<'a> {
&[
&coords_buf,
&scalars_buf,
- &input_size_buf,
&out_point_x,
&out_point_y,
&out_scalar_chunks,
- &num_y_workgroups_buf,
+ ¶ms_buf,
],
&thread_group_count,
&threads_per_threadgroup,
@@ -345,7 +359,7 @@ impl<'a> Transpose<'a> {
scalar_chunks: &[u32],
num_subtasks: usize,
input_size: usize,
- num_columns: u32,
+ num_columns: usize,
t_num_x_workgroups: usize,
t_num_y_workgroups: usize,
t_num_z_workgroups: usize,
@@ -363,8 +377,7 @@ impl<'a> Transpose<'a> {
let out_csc_val_idxs = helper.create_empty_buffer(scalar_chunks.len());
let out_curr = helper.create_empty_buffer(num_subtasks * (num_columns as usize) * 4);
- let params = vec![num_columns, input_size as u32];
- let params_buf = helper.create_buffer(¶ms);
+ let params_buf = helper.create_buffer(&vec![num_columns as u32, input_size as u32]);
let thread_group_count = helper.create_thread_group_size(
t_num_x_workgroups as u64,
@@ -417,10 +430,8 @@ impl<'a> SMVP<'a> {
point_y: &[u32],
input_size: usize,
num_subtasks: usize,
- num_columns: u32,
- s_num_x_workgroups: usize,
- s_num_y_workgroups: usize,
- s_num_z_workgroups: usize,
+ num_columns: usize,
+ simd_width: usize,
s_workgroup_size: usize,
) -> Result<(Vec, Vec, Vec), Box> {
let mut helper = MetalHelper::with_device(self.shader_manager.device().clone());
@@ -429,10 +440,11 @@ impl<'a> SMVP<'a> {
.get_shader(&ShaderType::SMVP)
.ok_or("SMVP shader not found")?;
+ let half_columns = num_columns / 2;
+
let bucket_size =
- (num_columns / 2) as usize * self.shader_manager.config().num_limbs * 4 * num_subtasks;
+ half_columns as usize * self.shader_manager.config().num_limbs * 4 * num_subtasks;
- // Create buffers
let row_ptr_buf = helper.create_buffer(&csc_col_ptr.to_vec());
let val_idx_buf = helper.create_buffer(&csc_val_idxs.to_vec());
let point_x_buf = helper.create_buffer(&point_x.to_vec());
@@ -445,21 +457,30 @@ impl<'a> SMVP<'a> {
// Execute in chunks
let num_subtask_chunk_size = 4u32;
for offset in (0..num_subtasks as u32).step_by(num_subtask_chunk_size as usize) {
- let params = vec![
+ let remaining_subtasks = (num_subtasks as u32 - offset).min(num_subtask_chunk_size);
+ let valid_threads = half_columns as u64 * remaining_subtasks as u64;
+
+ let max_y = ((valid_threads as usize)
+ / (s_workgroup_size * remaining_subtasks as usize))
+ .max(1);
+ let s_num_y_workgroups = simd_width.min(max_y) as u64;
+ let s_num_z_workgroups = remaining_subtasks as u64;
+
+ let threads_per_grid =
+ (s_workgroup_size as u64) * s_num_y_workgroups * s_num_z_workgroups;
+ let s_num_x_workgroups = (valid_threads + threads_per_grid - 1) / threads_per_grid; // ceil div
+
+ let params_buf = helper.create_buffer(&vec![
input_size as u32,
- s_num_y_workgroups as u32,
+ num_columns as u32,
num_subtasks as u32,
offset,
- ];
- let params_buf = helper.create_buffer(¶ms);
-
- let adjusted_x_workgroups =
- s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size as usize);
+ ]);
let thread_group_count = helper.create_thread_group_size(
- adjusted_x_workgroups as u64,
- s_num_y_workgroups as u64,
- s_num_z_workgroups as u64,
+ s_num_x_workgroups,
+ s_num_y_workgroups,
+ s_num_z_workgroups,
);
let threads_per_threadgroup =
helper.create_thread_group_size(s_workgroup_size as u64, 1, 1);
@@ -507,7 +528,7 @@ impl<'a> PBPR<'a> {
bucket_y: &[u32],
bucket_z: &[u32],
num_subtasks: usize,
- num_columns: u32,
+ num_columns: usize,
num_subtasks_per_bpr_1: usize,
num_subtasks_per_bpr_2: usize,
b_num_x_workgroups: usize,
@@ -542,12 +563,11 @@ impl<'a> PBPR<'a> {
for subtask_chunk_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_1) {
let params = vec![
subtask_chunk_idx as u32,
- num_columns,
+ num_columns as u32,
num_subtasks_per_bpr_1 as u32,
0u32, // dummy 4 bytes to align the Metal's uint3 16-byte style
];
let params_buf = helper.create_buffer(¶ms);
- let workgroup_size_buf = helper.create_buffer(&vec![b_workgroup_size as u32]);
let stage1_thread_group_count = helper.create_thread_group_size(
b_num_x_workgroups as u64,
@@ -567,7 +587,6 @@ impl<'a> PBPR<'a> {
&g_points_y_buf,
&g_points_z_buf,
¶ms_buf,
- &workgroup_size_buf,
],
&stage1_thread_group_count,
&stage1_threads_per_threadgroup,
@@ -578,12 +597,11 @@ impl<'a> PBPR<'a> {
for subtask_chunk_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_2) {
let params = vec![
subtask_chunk_idx as u32,
- num_columns,
+ num_columns as u32,
num_subtasks_per_bpr_2 as u32,
0u32, // dummy 4 bytes to align the Metal's uint3 16-byte style
];
let params_buf = helper.create_buffer(¶ms);
- let workgroup_size_buf = helper.create_buffer(&vec![b_workgroup_size as u32]);
let stage2_thread_group_count = helper.create_thread_group_size(
b_2_num_x_workgroups as u64,
@@ -603,7 +621,6 @@ impl<'a> PBPR<'a> {
&g_points_y_buf,
&g_points_z_buf,
¶ms_buf,
- &workgroup_size_buf,
],
&stage2_thread_group_count,
&stage2_threads_per_threadgroup,
@@ -623,8 +640,8 @@ impl<'a> PBPR<'a> {
/// Convenient wrapper that mimics the Arkworks VariableBaseMSM interface
/// Usage: metal_variable_base_msm(&bases, &scalars)
pub fn metal_variable_base_msm(
- bases: &[ark_bn254::G1Affine],
- scalars: &[ark_bn254::Fr],
+ mut bases: &[ark_bn254::G1Affine],
+ mut scalars: &[ark_bn254::Fr],
) -> Result> {
// Handle empty input case
if bases.is_empty() || scalars.is_empty() {
@@ -633,11 +650,48 @@ pub fn metal_variable_base_msm(
// Ensure bases and scalars have the same length
if bases.len() != scalars.len() {
- return Err("Bases and scalars must have the same length".into());
+ let min_len = std::cmp::min(bases.len(), scalars.len());
+ bases = &bases[..min_len];
+ scalars = &scalars[..min_len];
}
+ let input_size = bases.len();
+
+ // window_size is determined by experiment results
+ let window_size = if input_size < 16384 {
+ // 2^14
+ 8
+ } else if input_size < 524288 {
+ // 2^19
+ 13
+ } else if input_size <= 16777216 {
+ // 2^24
+ 15
+ } else {
+ // > 2^24
+ 16
+ };
+
+ // workgroup size will be adjusted based on the scale factor
+ let scale_factor = if input_size <= 4096 {
+ // 2^12
+ 1 // 2^0
+ } else if input_size <= 65536 {
+ // 2^16
+ 2 // 2^1
+ } else if input_size <= 1048576 {
+ // 2^20
+ 1 << 2 // 2^2
+ } else if input_size <= 16777216 {
+ // 2^24
+ 1 << 3 // 2^3
+ } else {
+ // > 2^24
+ 1 << 4 // 2^4
+ };
+
let pipeline = MetalMSMPipeline::with_default_config()?;
- pipeline.execute(bases, scalars)
+ pipeline.execute_pipeline(bases, scalars, input_size, window_size, scale_factor)
}
/// Test utilities module - available for both unit tests and integration tests
@@ -684,7 +738,7 @@ mod tests {
#[test]
fn test_metal_msm_pipeline() {
- let log_input_size = 20;
+ let log_input_size = 16;
let input_size = 1 << log_input_size;
println!("Generating {} elements", input_size);
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
index e56b0727..83f54085 100644
--- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
@@ -4,14 +4,15 @@
using namespace metal;
#include "../misc/get_constant.metal"
-
inline BigIntResult bigint_add_unsafe(
BigInt lhs,
- BigInt rhs
-) {
+ BigInt rhs)
+{
BigIntResult res;
res.carry = 0;
uint mask = (1 << LOG_LIMB_SIZE) - 1;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
uint c = lhs.limbs[i] + rhs.limbs[i] + res.carry;
res.value.limbs[i] = c & mask;
@@ -22,12 +23,14 @@ inline BigIntResult bigint_add_unsafe(
inline BigIntResultWide bigint_add_wide(
BigInt lhs,
- BigInt rhs
-) {
+ BigInt rhs)
+{
BigIntResultWide res;
res.carry = 0;
uint mask = (1 << LOG_LIMB_SIZE) - 1;
uint carry = 0;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
uint c = lhs.limbs[i] + rhs.limbs[i] + carry;
res.value.limbs[i] = c & mask;
@@ -40,10 +43,12 @@ inline BigIntResultWide bigint_add_wide(
inline BigIntResult bigint_sub(
BigInt lhs,
- BigInt rhs
-) {
+ BigInt rhs)
+{
BigIntResult res;
res.carry = 0;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry;
if (lhs.limbs[i] < rhs.limbs[i] + res.carry) {
@@ -56,13 +61,14 @@ inline BigIntResult bigint_sub(
return res;
}
-
inline BigIntResultWide bigint_sub_wide(
BigIntWide lhs,
- BigIntWide rhs
-) {
+ BigIntWide rhs)
+{
BigIntResultWide res;
res.carry = 0;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
res.value.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - res.carry;
if (lhs.limbs[i] < rhs.limbs[i] + res.carry) {
@@ -77,58 +83,73 @@ inline BigIntResultWide bigint_sub_wide(
inline bool bigint_gte(
BigInt lhs,
- BigInt rhs
-) {
- // for (uint i = NUM_LIMBS-1; i >= 0; i--) is troublesome from unknown reason
+ BigInt rhs)
+{
+#pragma unroll(16)
for (uint idx = 0; idx < NUM_LIMBS; idx++) {
uint i = NUM_LIMBS - 1 - idx;
- if (lhs.limbs[i] < rhs.limbs[i]) return false;
- else if (lhs.limbs[i] > rhs.limbs[i]) return true;
+ if (lhs.limbs[i] < rhs.limbs[i])
+ return false;
+ else if (lhs.limbs[i] > rhs.limbs[i])
+ return true;
}
return true;
}
inline bool bigint_wide_gte(
BigIntWide lhs,
- BigIntWide rhs
-) {
+ BigIntWide rhs)
+{
+#pragma unroll(17)
for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx++) {
uint i = NUM_LIMBS_WIDE - 1 - idx;
- if (lhs.limbs[i] < rhs.limbs[i]) return false;
- else if (lhs.limbs[i] > rhs.limbs[i]) return true;
+ if (lhs.limbs[i] < rhs.limbs[i])
+ return false;
+ else if (lhs.limbs[i] > rhs.limbs[i])
+ return true;
}
return true;
}
inline bool bigint_eq(
BigInt lhs,
- BigInt rhs
-) {
+ BigInt rhs)
+{
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
- if (lhs.limbs[i] != rhs.limbs[i]) return false;
+ if (lhs.limbs[i] != rhs.limbs[i])
+ return false;
}
return true;
}
-inline bool is_bigint_zero(BigInt x) {
+inline bool is_bigint_zero(BigInt x)
+{
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
- if (x.limbs[i] != 0) return false;
+ if (x.limbs[i] != 0)
+ return false;
}
return true;
}
// Conversion functions
-inline BigIntWide bigint_to_wide(BigInt x) {
+inline BigIntWide bigint_to_wide(BigInt x)
+{
BigIntWide res = bigint_zero_wide();
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
res.limbs[i] = x.limbs[i];
}
return res;
}
-inline BigInt bigint_from_wide(BigIntWide x) {
+inline BigInt bigint_from_wide(BigIntWide x)
+{
BigInt res = bigint_zero();
- // ignore the last limb
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
res.limbs[i] = x.limbs[i];
}
@@ -136,18 +157,22 @@ inline BigInt bigint_from_wide(BigIntWide x) {
}
// Overload Operators
-constexpr BigInt operator+(const BigInt lhs, const BigInt rhs) {
+constexpr BigInt operator+(const BigInt lhs, const BigInt rhs)
+{
return bigint_add_unsafe(lhs, rhs).value;
}
-constexpr BigInt operator-(const BigInt lhs, const BigInt rhs) {
+constexpr BigInt operator-(const BigInt lhs, const BigInt rhs)
+{
return bigint_sub(lhs, rhs).value;
}
-constexpr bool operator>=(const BigInt lhs, const BigInt rhs) {
+constexpr bool operator>=(const BigInt lhs, const BigInt rhs)
+{
return bigint_gte(lhs, rhs);
}
-constexpr bool operator==(const BigInt lhs, const BigInt rhs) {
+constexpr bool operator==(const BigInt lhs, const BigInt rhs)
+{
return bigint_eq(lhs, rhs);
}
\ No newline at end of file
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal
index f38216aa..df12c00b 100644
--- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal
@@ -1,16 +1,16 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "bigint.metal"
+#include
+#include
kernel void test_bigint_add_unsafe(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* b [[ buffer(1) ]],
- device BigInt* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* b [[buffer(1)]],
+ device BigInt* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
BigIntResult result = bigint_add_unsafe(*a, *b);
*res = result.value;
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
index 8b6f5191..fc72ef84 100644
--- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
@@ -1,16 +1,16 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "bigint.metal"
+#include
+#include
kernel void test_bigint_add_wide(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* b [[ buffer(1) ]],
- device BigIntWide* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* b [[buffer(1)]],
+ device BigIntWide* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
BigIntResultWide result = bigint_add_wide(*a, *b);
*res = result.value;
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
index af029941..4c719b41 100644
--- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
@@ -1,16 +1,16 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "bigint.metal"
+#include
+#include
kernel void test_bigint_sub(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* b [[ buffer(1) ]],
- device BigInt* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* b [[buffer(1)]],
+ device BigInt* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
BigIntResult result = bigint_sub(*a, *b);
*res = result.value;
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/constants.metal b/mopro-msm/src/msm/metal_msm/shader/constants.metal
index 0220c9b5..53af531d 100644
--- a/mopro-msm/src/msm/metal_msm/shader/constants.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/constants.metal
@@ -9,9 +9,6 @@
#define N0 25481
#define NSAFE 1
#define SLACK 2
-#define CHUNK_SIZE 16
-#define NUM_COLUMNS 65536
-#define NUM_SUBTASKS 16
constant uint32_t BARRETT_MU[NUM_LIMBS] = {
37093,
6591,
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
index baa7085d..c398f01a 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
@@ -3,12 +3,13 @@
#pragma once
using namespace metal;
-#include
-#include
#include "../mont_backend/mont.metal"
#include "./utils.metal"
+#include
+#include
-inline Jacobian jacobian_dbl_2009_l(Jacobian pt) {
+inline Jacobian jacobian_dbl_2009_l(Jacobian pt)
+{
BigInt x = pt.x;
BigInt y = pt.y;
BigInt z = pt.z;
@@ -42,10 +43,14 @@ inline Jacobian jacobian_dbl_2009_l(Jacobian pt) {
return result;
}
-inline Jacobian jacobian_add_2007_bl(Jacobian a, Jacobian b) {
- if (is_jacobian_zero(a)) return b;
- if (is_jacobian_zero(b)) return a;
- if (a == b) return jacobian_dbl_2009_l(a);
+inline Jacobian jacobian_add_2007_bl(Jacobian a, Jacobian b)
+{
+ if (is_jacobian_zero(a))
+ return b;
+ if (is_jacobian_zero(b))
+ return a;
+ if (a == b)
+ return jacobian_dbl_2009_l(a);
BigInt x1 = a.x;
BigInt y1 = a.y;
@@ -96,7 +101,8 @@ inline Jacobian jacobian_add_2007_bl(Jacobian a, Jacobian b) {
// Notice that this algo only takes standard form instead of Montgomery form
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
-inline Jacobian jacobian_madd_2007_bl(Jacobian a, Affine b) {
+inline Jacobian jacobian_madd_2007_bl(Jacobian a, Affine b)
+{
BigInt x1 = a.x;
BigInt y1 = a.y;
BigInt z1 = a.z;
@@ -105,53 +111,53 @@ inline Jacobian jacobian_madd_2007_bl(Jacobian a, Affine b) {
// Z1Z1 = Z1^2
BigInt z1z1 = mont_mul_cios(z1, z1);
-
+
// U2 = X2*Z1Z1
BigInt u2 = mont_mul_cios(x2, z1z1);
-
+
// S2 = Y2*Z1*Z1Z1
BigInt temp_s2 = mont_mul_cios(y2, z1);
BigInt s2 = mont_mul_cios(temp_s2, z1z1);
-
+
// H = U2-X1
BigInt h = ff_sub(u2, x1);
-
+
// HH = H^2
BigInt hh = mont_mul_cios(h, h);
-
+
// I = 4*HH
BigInt i = ff_add(hh, hh); // *2
- i = ff_add(i, i); // *4
-
+ i = ff_add(i, i); // *4
+
// J = H*I
BigInt j = mont_mul_cios(h, i);
-
+
// r = 2*(S2-Y1)
BigInt s2_minus_y1 = ff_sub(s2, y1);
BigInt r = ff_add(s2_minus_y1, s2_minus_y1);
-
+
// V = X1*I
BigInt v = mont_mul_cios(x1, i);
-
+
// X3 = r^2-J-2*V
BigInt r2 = mont_mul_cios(r, r);
BigInt v2 = ff_add(v, v);
BigInt jv2 = ff_add(j, v2);
BigInt x3 = ff_sub(r2, jv2);
-
+
// Y3 = r*(V-X3)-2*Y1*J
BigInt v_minus_x3 = ff_sub(v, x3);
BigInt r_vmx3 = mont_mul_cios(r, v_minus_x3);
BigInt y1j = mont_mul_cios(y1, j);
BigInt y1j2 = ff_add(y1j, y1j);
BigInt y3 = ff_sub(r_vmx3, y1j2);
-
+
// Z3 = (Z1+H)^2-Z1Z1-HH
BigInt z1_plus_h = ff_add(z1, h);
BigInt z1_plus_h_squared = mont_mul_cios(z1_plus_h, z1_plus_h);
BigInt temp = ff_sub(z1_plus_h_squared, z1z1);
BigInt z3 = ff_sub(temp, hh);
-
+
Jacobian result;
result.x = x3;
result.y = y3;
@@ -161,8 +167,8 @@ inline Jacobian jacobian_madd_2007_bl(Jacobian a, Affine b) {
inline Jacobian jacobian_scalar_mul(
Jacobian pt,
- uint scalar
-) {
+ uint scalar)
+{
// Handle special cases first
if (scalar == 0 || is_bigint_zero(pt.z)) {
return get_bn254_zero_mont();
@@ -182,12 +188,15 @@ inline Jacobian jacobian_scalar_mul(
temp = jacobian_dbl_2009_l(temp);
s = s >> 1;
}
-
+
return result;
}
-inline Jacobian jacobian_neg(Jacobian pt) {
- if (is_jacobian_zero(pt)) { return pt; }
+inline Jacobian jacobian_neg(Jacobian pt)
+{
+ if (is_jacobian_zero(pt)) {
+ return pt;
+ }
// Negate Y (mod p): newY = p - Y
BigInt p = MODULUS;
@@ -201,14 +210,17 @@ inline Jacobian jacobian_neg(Jacobian pt) {
}
// Override operators in Jacobian
-constexpr Jacobian operator+(const Jacobian lhs, const Jacobian rhs) {
+constexpr Jacobian operator+(const Jacobian lhs, const Jacobian rhs)
+{
return jacobian_add_2007_bl(lhs, rhs);
}
-constexpr Jacobian operator+(const Jacobian lhs, const Affine rhs) {
+constexpr Jacobian operator+(const Jacobian lhs, const Affine rhs)
+{
return jacobian_madd_2007_bl(lhs, rhs);
}
-constexpr Jacobian operator-(const Jacobian pt) {
+constexpr Jacobian operator-(const Jacobian pt)
+{
return jacobian_neg(pt);
}
\ No newline at end of file
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal
index 8193dc18..b2e0c79c 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal
@@ -1,22 +1,22 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "jacobian.metal"
+#include
+#include
kernel void test_jacobian_add_2007_bl(
- device BigInt* a_xr [[ buffer(0) ]],
- device BigInt* a_yr [[ buffer(1) ]],
- device BigInt* a_zr [[ buffer(2) ]],
- device BigInt* b_xr [[ buffer(3) ]],
- device BigInt* b_yr [[ buffer(4) ]],
- device BigInt* b_zr [[ buffer(5) ]],
- device BigInt* result_xr [[ buffer(6) ]],
- device BigInt* result_yr [[ buffer(7) ]],
- device BigInt* result_zr [[ buffer(8) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a_xr [[buffer(0)]],
+ device BigInt* a_yr [[buffer(1)]],
+ device BigInt* a_zr [[buffer(2)]],
+ device BigInt* b_xr [[buffer(3)]],
+ device BigInt* b_yr [[buffer(4)]],
+ device BigInt* b_zr [[buffer(5)]],
+ device BigInt* result_xr [[buffer(6)]],
+ device BigInt* result_yr [[buffer(7)]],
+ device BigInt* result_zr [[buffer(8)]],
+ uint gid [[thread_position_in_grid]])
+{
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
@@ -24,8 +24,14 @@ kernel void test_jacobian_add_2007_bl(
BigInt y2 = *b_yr;
BigInt z2 = *b_zr;
- Jacobian a; a.x = x1; a.y = y1; a.z = z1;
- Jacobian b; b.x = x2; b.y = y2; b.z = z2;
+ Jacobian a;
+ a.x = x1;
+ a.y = y1;
+ a.z = z1;
+ Jacobian b;
+ b.x = x2;
+ b.y = y2;
+ b.z = z2;
Jacobian res = jacobian_add_2007_bl(a, b);
*result_xr = res.x;
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal
index cca7f899..97f914bb 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal
@@ -1,24 +1,27 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "jacobian.metal"
+#include
+#include
kernel void test_jacobian_dbl_2009_l(
- device BigInt* a_xr [[ buffer(0) ]],
- device BigInt* a_yr [[ buffer(1) ]],
- device BigInt* a_zr [[ buffer(2) ]],
- device BigInt* result_xr [[ buffer(3) ]],
- device BigInt* result_yr [[ buffer(4) ]],
- device BigInt* result_zr [[ buffer(5) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a_xr [[buffer(0)]],
+ device BigInt* a_yr [[buffer(1)]],
+ device BigInt* a_zr [[buffer(2)]],
+ device BigInt* result_xr [[buffer(3)]],
+ device BigInt* result_yr [[buffer(4)]],
+ device BigInt* result_zr [[buffer(5)]],
+ uint gid [[thread_position_in_grid]])
+{
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
- Jacobian a; a.x = x1; a.y = y1; a.z = z1;
+ Jacobian a;
+ a.x = x1;
+ a.y = y1;
+ a.z = z1;
Jacobian res = jacobian_dbl_2009_l(a);
*result_xr = res.x;
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal
index b83ccaa7..3cb4805f 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal
@@ -1,29 +1,34 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "jacobian.metal"
+#include
+#include
kernel void test_jacobian_madd_2007_bl(
- device BigInt* a_xr [[ buffer(0) ]],
- device BigInt* a_yr [[ buffer(1) ]],
- device BigInt* a_zr [[ buffer(2) ]],
- device BigInt* b_xr [[ buffer(3) ]],
- device BigInt* b_yr [[ buffer(4) ]],
- device BigInt* result_xr [[ buffer(5) ]],
- device BigInt* result_yr [[ buffer(6) ]],
- device BigInt* result_zr [[ buffer(7) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a_xr [[buffer(0)]],
+ device BigInt* a_yr [[buffer(1)]],
+ device BigInt* a_zr [[buffer(2)]],
+ device BigInt* b_xr [[buffer(3)]],
+ device BigInt* b_yr [[buffer(4)]],
+ device BigInt* result_xr [[buffer(5)]],
+ device BigInt* result_yr [[buffer(6)]],
+ device BigInt* result_zr [[buffer(7)]],
+ uint gid [[thread_position_in_grid]])
+{
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
BigInt x2 = *b_xr;
BigInt y2 = *b_yr;
- Jacobian a; a.x = x1; a.y = y1; a.z = z1;
- Affine b; b.x = x2; b.y = y2;
+ Jacobian a;
+ a.x = x1;
+ a.y = y1;
+ a.z = z1;
+ Affine b;
+ b.x = x2;
+ b.y = y2;
Jacobian res = jacobian_madd_2007_bl(a, b);
*result_xr = res.x;
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_neg.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_neg.metal
index bb1e2037..89e6ec4a 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_neg.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_neg.metal
@@ -1,22 +1,25 @@
using namespace metal;
-#include
-#include
#include "jacobian.metal"
+#include
+#include
kernel void test_jacobian_neg(
- device BigInt* a_xr [[ buffer(0) ]],
- device BigInt* a_yr [[ buffer(1) ]],
- device BigInt* a_zr [[ buffer(2) ]],
- device BigInt* result_xr [[ buffer(3) ]],
- device BigInt* result_yr [[ buffer(4) ]],
- device BigInt* result_zr [[ buffer(5) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a_xr [[buffer(0)]],
+ device BigInt* a_yr [[buffer(1)]],
+ device BigInt* a_zr [[buffer(2)]],
+ device BigInt* result_xr [[buffer(3)]],
+ device BigInt* result_yr [[buffer(4)]],
+ device BigInt* result_zr [[buffer(5)]],
+ uint gid [[thread_position_in_grid]])
+{
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
- Jacobian a; a.x = x1; a.y = y1; a.z = z1;
+ Jacobian a;
+ a.x = x1;
+ a.y = y1;
+ a.z = z1;
Jacobian res = jacobian_neg(a);
*result_xr = res.x;
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_scalar_mul.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_scalar_mul.metal
index ff41d880..d81dd752 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_scalar_mul.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_scalar_mul.metal
@@ -1,14 +1,14 @@
using namespace metal;
-#include
-#include
#include "jacobian.metal"
+#include
+#include
kernel void test_jacobian_scalar_mul(
- device Jacobian& a [[ buffer(0) ]],
- device uint* scalar [[ buffer(1) ]],
- device Jacobian& result [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device Jacobian& a [[buffer(0)]],
+ device uint* scalar [[buffer(1)]],
+ device Jacobian& result [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
uint s = *scalar;
result = jacobian_scalar_mul(a, s);
}
\ No newline at end of file
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal b/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal
index 813a42ca..694d6312 100644
--- a/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/utils.metal
@@ -1,24 +1,31 @@
#pragma once
using namespace metal;
-#include
-#include
#include "../bigint/bigint.metal"
#include "../misc/get_constant.metal"
+#include
+#include
-inline bool is_jacobian_zero(Jacobian a) {
+inline bool is_jacobian_zero(Jacobian a)
+{
return is_bigint_zero(a.z);
}
-inline bool jacobian_eq(Jacobian lhs, Jacobian rhs) {
+inline bool jacobian_eq(Jacobian lhs, Jacobian rhs)
+{
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
- if (lhs.x.limbs[i] != rhs.x.limbs[i]) return false;
- else if (lhs.y.limbs[i] != rhs.y.limbs[i]) return false;
- else if (lhs.z.limbs[i] != rhs.z.limbs[i]) return false;
+ if (lhs.x.limbs[i] != rhs.x.limbs[i])
+ return false;
+ else if (lhs.y.limbs[i] != rhs.y.limbs[i])
+ return false;
+ else if (lhs.z.limbs[i] != rhs.z.limbs[i])
+ return false;
}
return true;
}
-constexpr bool operator==(const Jacobian lhs, const Jacobian rhs) {
+constexpr bool operator==(const Jacobian lhs, const Jacobian rhs)
+{
return jacobian_eq(lhs, rhs);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/barrett_reduction.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/barrett_reduction.metal
index 7a72f685..10fb82e1 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/barrett_reduction.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/barrett_reduction.metal
@@ -1,60 +1,67 @@
#pragma once
-#include
-#include
#include "../field/ff.metal"
+#include
+#include
using namespace metal;
#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
- #include
- constant os_log logger(/*subsystem=*/"barret_reduction", /*category=*/"metal");
- #define LOG_DEBUG_DUPL(...) logger.log_debug(__VA_ARGS__)
+#include
+constant os_log logger(/*subsystem=*/"barret_reduction", /*category=*/"metal");
+#define LOG_DEBUG_DUPL(...) logger.log_debug(__VA_ARGS__)
#else
- #define LOG_DEBUG_DUPL(...) ((void)0)
+#define LOG_DEBUG_DUPL(...) ((void)0)
#endif
-BigIntExtraWide mul(BigIntWide a, BigIntWide b) {
+BigIntExtraWide mul(BigIntWide a, BigIntWide b)
+{
BigIntExtraWide res = bigint_zero_extra_wide();
-
+
+#pragma unroll(17)
for (uint i = 0; i < NUM_LIMBS_WIDE; i++) {
+#pragma unroll(17)
for (uint j = 0; j < NUM_LIMBS_WIDE; j++) {
ulong c = (ulong)a.limbs[i] * (ulong)b.limbs[j];
- res.limbs[i+j] += c & MASK;
- res.limbs[i+j+1] += c >> LOG_LIMB_SIZE;
+ res.limbs[i + j] += c & MASK;
+ res.limbs[i + j + 1] += c >> LOG_LIMB_SIZE;
}
}
- // Start from 0 and carry the extra over to the next index.
+// Start from 0 and carry the extra over to the next index.
+#pragma unroll(32)
for (uint i = 0; i < NUM_LIMBS_EXTRA_WIDE; i++) {
- res.limbs[i+1] += res.limbs[i] >> LOG_LIMB_SIZE;
+ res.limbs[i + 1] += res.limbs[i] >> LOG_LIMB_SIZE;
res.limbs[i] = res.limbs[i] & MASK;
}
return res;
}
-BigIntResultExtraWide sub_512(BigIntExtraWide a, BigIntExtraWide b) {
+BigIntResultExtraWide sub_512(BigIntExtraWide a, BigIntExtraWide b)
+{
BigIntResultExtraWide res;
res.value = bigint_zero_extra_wide();
res.carry = 0;
+#pragma unroll(32)
for (uint i = 0; i < NUM_LIMBS_EXTRA_WIDE; i++) {
res.value.limbs[i] = a.limbs[i] - b.limbs[i] - res.carry;
if (a.limbs[i] < (b.limbs[i] + res.carry)) {
res.value.limbs[i] += (MASK + 1);
res.carry = 1;
- }
- else {
+ } else {
res.carry = 0;
}
}
return res;
}
-BigIntResultExtraWide add_512(BigIntExtraWide a, BigIntExtraWide b) {
+BigIntResultExtraWide add_512(BigIntExtraWide a, BigIntExtraWide b)
+{
BigIntResultExtraWide res;
res.value = bigint_zero_extra_wide();
res.carry = 0;
+#pragma unroll(32)
for (uint i = 0; i < NUM_LIMBS_EXTRA_WIDE; i++) {
ulong sum = (ulong)a.limbs[i] + (ulong)b.limbs[i] + res.carry;
res.value.limbs[i] = sum & MASK;
@@ -63,20 +70,19 @@ BigIntResultExtraWide add_512(BigIntExtraWide a, BigIntExtraWide b) {
return res;
}
-BigInt get_higher_with_slack(BigIntExtraWide a) {
+BigInt get_higher_with_slack(BigIntExtraWide a)
+{
BigInt out = bigint_zero();
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
- out.limbs[i] = ((a.limbs[i + NUM_LIMBS] << SLACK) +
- (a.limbs[i + NUM_LIMBS - 1] >> (LOG_LIMB_SIZE - SLACK))) & MASK;
+ out.limbs[i] = ((a.limbs[i + NUM_LIMBS] << SLACK) + (a.limbs[i + NUM_LIMBS - 1] >> (LOG_LIMB_SIZE - SLACK))) & MASK;
}
return out;
}
-BigInt barrett_reduce(BigIntExtraWide a) {
- for (uint i = 0; i < NUM_LIMBS; i++) {
- LOG_DEBUG_DUPL("res.limbs[%u] = %u", i, a.limbs[i]);
- }
-
+BigInt barrett_reduce(BigIntExtraWide a)
+{
BigInt p = MODULUS;
BigInt mu = get_mu();
@@ -96,6 +102,8 @@ BigInt barrett_reduce(BigIntExtraWide a) {
}
BigInt r = bigint_zero();
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
r.limbs[i] = r_wide.limbs[i];
}
@@ -103,7 +111,8 @@ BigInt barrett_reduce(BigIntExtraWide a) {
return ff_reduce(r);
}
-inline BigInt field_mul(BigIntWide a, BigIntWide b) {
+inline BigInt field_mul(BigIntWide a, BigIntWide b)
+{
BigIntExtraWide xy = mul(a, b);
return barrett_reduce(xy);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/convert_point_coords_and_decompose_scalars.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/convert_point_coords_and_decompose_scalars.metal
index 430338b9..7da2735b 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/convert_point_coords_and_decompose_scalars.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/convert_point_coords_and_decompose_scalars.metal
@@ -1,34 +1,39 @@
-#include
-#include
+#include "../misc/get_constant.metal"
#include "barrett_reduction.metal"
#include "extract_word_from_bytes_le.metal"
-#include "../misc/get_constant.metal"
+#include
+#include
using namespace metal;
#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
- #include
- constant os_log logger_kernel(/*subsystem=*/"pt_conversion", /*category=*/"metal");
- #define LOG_DEBUG(...) logger_kernel.log_debug(__VA_ARGS__)
+#include
+constant os_log logger_kernel(/*subsystem=*/"pt_conversion", /*category=*/"metal");
+#define LOG_DEBUG(...) logger_kernel.log_debug(__VA_ARGS__)
#else
- #define LOG_DEBUG(...) ((void)0)
+#define LOG_DEBUG(...) ((void)0)
#endif
kernel void convert_point_coords_and_decompose_scalars(
- device const uint* coords [[buffer(0)]],
- device const uint* scalars [[buffer(1)]],
- constant uint& input_size [[buffer(2)]],
- device BigInt* point_x [[buffer(3)]],
- device BigInt* point_y [[buffer(4)]],
- device uint* chunks [[buffer(5)]],
- device const uint* num_y_workgroups [[buffer(6)]],
- uint3 gid [[thread_position_in_grid]]
-) {
+ device const uint* coords [[buffer(0), access(read)]],
+ device const uint* scalars [[buffer(1), access(read)]],
+ device BigInt* point_x [[buffer(2), access(write)]],
+ device BigInt* point_y [[buffer(3), access(write)]],
+ device uint* chunks [[buffer(4), access(write)]],
+ constant uint4& params [[buffer(5), access(read)]],
+ uint3 gid [[thread_position_in_grid]],
+ uint3 threadgroup_size [[threadgroups_per_grid]])
+{
+ const uint input_size = params[0];
+ const uint window_size = params[1];
+ const uint num_columns = params[2];
+ const uint num_subtask = params[3];
+
uint gidx = gid.x;
uint gidy = gid.y;
- uint id = gidx * (*num_y_workgroups) + gidy;
+ uint id = gidx * threadgroup_size.y + gidy;
// 1) Convert coords to BigInt in Montgomery form
- // We read 16 halfwords for x and 16 halfwords for y.
+ // We read 16 halfwords for x and 16 halfwords for y.
uint x_bytes[16];
uint y_bytes[16];
@@ -37,22 +42,24 @@ kernel void convert_point_coords_and_decompose_scalars(
// so each point uses 16 x 32-bit = 16 indices.
uint base_offset = id * 16u;
+#pragma unroll(8)
for (uint i = 0u; i < 8u; i++) {
// coords[base_offset + i] is the i-th 32-bit chunk of x
// coords[base_offset + 8 + i] is the i-th 32-bit chunk of y
uint x_val = coords[base_offset + i];
uint y_val = coords[base_offset + 8u + i];
- x_bytes[15 - (i * 2)] = x_val & 0xFFFFu;
+ x_bytes[15 - (i * 2)] = x_val & 0xFFFFu;
x_bytes[15 - (i * 2) - 1] = x_val >> 16u;
- y_bytes[15 - (i * 2)] = y_val & 0xFFFFu;
+ y_bytes[15 - (i * 2)] = y_val & 0xFFFFu;
y_bytes[15 - (i * 2) - 1] = y_val >> 16u;
}
BigInt x_bigint = bigint_zero();
BigInt y_bigint = bigint_zero();
+#pragma unroll(15)
for (uint i = 0; i < NUM_LIMBS - 1u; i++) {
x_bigint.limbs[i] = extract_word_from_bytes_le(x_bytes, i, LOG_LIMB_SIZE);
y_bigint.limbs[i] = extract_word_from_bytes_le(y_bytes, i, LOG_LIMB_SIZE);
@@ -74,6 +81,8 @@ kernel void convert_point_coords_and_decompose_scalars(
// 2) Decompose scalars: read 8 32-bit values => 16 halfwords in `scalar_bytes`.
uint scalar_bytes[16];
+
+#pragma unroll(8)
for (uint i = 0; i < 8u; i++) {
uint s = scalars[id * 8u + i];
uint hi = s >> 16u;
@@ -82,38 +91,32 @@ kernel void convert_point_coords_and_decompose_scalars(
scalar_bytes[15u - (i * 2u) - 1u] = hi;
}
- // Extract wNAF representation. each chunk is CHUNK_SIZE bits from the scalar.
- uint chunks_arr[NUM_SUBTASKS];
- for (uint i = 0; i < NUM_SUBTASKS - 1u; i++) {
- chunks_arr[i] = extract_word_from_bytes_le(scalar_bytes, i, CHUNK_SIZE);
- }
-
- // The last chunk is special if (NUM_SUBTASKS * CHUNK_SIZE > 256)
- chunks_arr[NUM_SUBTASKS - 1] =
- scalar_bytes[0] >> (((NUM_SUBTASKS * CHUNK_SIZE - 256u) + 16u) - CHUNK_SIZE);
-
- // 3) Signed wNAF in the range [−(l−1), (l−1)]
- uint l = NUM_COLUMNS;
+ // Extract wNAF representation. each chunk is window_size bits from the scalar.
+ uint l = num_columns;
uint s = l / 2u;
-
- int signed_slices[NUM_SUBTASKS];
uint carry = 0;
- for (uint i = 0; i < NUM_SUBTASKS; i++) {
- int slice_val = int(chunks_arr[i] + carry);
+
+ for (uint i = 0; i < num_subtask; i++) {
+ // Extract chunk on-demand
+ uint chunk_val;
+ if (i < num_subtask - 1u) {
+ chunk_val = extract_word_from_bytes_le(scalar_bytes, i, window_size);
+ } else {
+ // The last chunk is special
+ chunk_val = scalar_bytes[0] >> (((num_subtask * window_size - 256u) + 16u) - window_size);
+ }
+
+ // Process signed wNAF directly
+ int slice_val = int(chunk_val + carry);
if (slice_val >= int(s)) {
slice_val = (int(l) - slice_val) * (-1);
carry = 1u;
- }
- else {
+ } else {
carry = 0u;
}
- signed_slices[i] = slice_val;
- }
- // Store final values into chunks
- for (uint i = 0; i < NUM_SUBTASKS; i++) {
+ // Store final value directly
uint offset = i * input_size;
- // shift negative slices by +s to keep them in [0, l) range
- chunks[id + offset] = uint(signed_slices[i]) + s;
+ chunks[id + offset] = uint(slice_val) + s;
}
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/extract_word_from_bytes_le.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/extract_word_from_bytes_le.metal
index f32d7b74..76b393b6 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/extract_word_from_bytes_le.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/extract_word_from_bytes_le.metal
@@ -1,32 +1,31 @@
#pragma once
-#include
#include
+#include
using namespace metal;
inline uint32_t extract_word_from_bytes_le(
const thread uint32_t* input,
uint32_t word_idx,
- uint32_t chunk_size
-) {
+ uint32_t window_size)
+{
uint32_t word = 0;
- const uint32_t start_byte_idx = 15 - ((word_idx * chunk_size + chunk_size) / 16);
- const uint32_t end_byte_idx = 15 - ((word_idx * chunk_size) / 16);
-
- const uint32_t start_byte_offset = (word_idx * chunk_size + chunk_size) % 16;
- const uint32_t end_byte_offset = (word_idx * chunk_size) % 16;
-
+ const uint32_t start_byte_idx = 15 - ((word_idx * window_size + window_size) / 16);
+ const uint32_t end_byte_idx = 15 - ((word_idx * window_size) / 16);
+
+ const uint32_t start_byte_offset = (word_idx * window_size + window_size) % 16;
+ const uint32_t end_byte_offset = (word_idx * window_size) % 16;
+
uint32_t mask = 0;
if (start_byte_offset > 0) {
mask = (2 << (start_byte_offset - 1)) - 1;
}
if (start_byte_idx == end_byte_idx) {
word = (input[start_byte_idx] & mask) >> end_byte_offset;
- }
- else {
+ } else {
word = (input[start_byte_idx] & mask) << (16 - end_byte_offset);
word += input[end_byte_idx] >> end_byte_offset;
}
-
+
return word;
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_barrett_reduction.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_barrett_reduction.metal
index 1cb44833..86c65d4d 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_barrett_reduction.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_barrett_reduction.metal
@@ -1,21 +1,21 @@
-#include
-#include
#include "barrett_reduction.metal"
+#include
+#include
using namespace metal;
#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
- #include
- constant os_log barrett_reduction_logger_kernel(/*subsystem=*/"barrett_reduction", /*category=*/"metal");
- #define LOG_DEBUG(...) barrett_reduction_logger_kernel.log_debug(__VA_ARGS__)
+#include
+constant os_log barrett_reduction_logger_kernel(/*subsystem=*/"barrett_reduction", /*category=*/"metal");
+#define LOG_DEBUG(...) barrett_reduction_logger_kernel.log_debug(__VA_ARGS__)
#else
- #define LOG_DEBUG(...) ((void)0)
+#define LOG_DEBUG(...) ((void)0)
#endif
kernel void test_barrett_reduction(
- device BigIntExtraWide* a [[ buffer(0) ]],
- device BigInt* res [[ buffer(1) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigIntExtraWide* a [[buffer(0)]],
+ device BigInt* res [[buffer(1)]],
+ uint gid [[thread_position_in_grid]])
+{
*res = barrett_reduce(*a);
LOG_DEBUG("pointer: %p", res);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_field_mul.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_field_mul.metal
index 312fe165..47ca0509 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_field_mul.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/kernel_field_mul.metal
@@ -1,13 +1,13 @@
-#include
-#include
#include "barrett_reduction.metal"
+#include
+#include
using namespace metal;
kernel void test_field_mul(
- device BigIntWide* a [[ buffer(0) ]],
- device BigIntWide* b [[ buffer(1) ]],
- device BigInt* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigIntWide* a [[buffer(0)]],
+ device BigIntWide* b [[buffer(1)]],
+ device BigInt* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
*res = field_mul(*a, *b);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/pbpr.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/pbpr.metal
index c77fa03b..650ca175 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/pbpr.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/pbpr.metal
@@ -1,19 +1,20 @@
-#include
#include "../curve/jacobian.metal"
#include "../misc/get_constant.metal"
+#include
using namespace metal;
#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
- #include
- constant os_log pbpr_logger_kernel(/*subsystem=*/"pbpr", /*category=*/"metal");
- #define LOG_DEBUG(...) pbpr_logger_kernel.log_debug(__VA_ARGS__)
+#include
+constant os_log pbpr_logger_kernel(/*subsystem=*/"pbpr", /*category=*/"metal");
+#define LOG_DEBUG(...) pbpr_logger_kernel.log_debug(__VA_ARGS__)
#else
- #define LOG_DEBUG(...) ((void)0)
+#define LOG_DEBUG(...) ((void)0)
#endif
// This double-and-add code is adapted from the ZPrize test harness:
// https://github.com/demox-labs/webgpu-msm/blob/main/src/reference/webgpu/wgsl/Curve.ts#L78.
-static Jacobian double_and_add(Jacobian point, uint scalar) {
+static Jacobian double_and_add(Jacobian point, uint scalar)
+{
Jacobian result = get_bn254_zero_mont(); // Point at infinity
uint s = scalar;
@@ -30,22 +31,22 @@ static Jacobian double_and_add(Jacobian point, uint scalar) {
}
kernel void bpr_stage_1(
- device BigInt* bucket_sum_x [[ buffer(0) ]],
- device BigInt* bucket_sum_y [[ buffer(1) ]],
- device BigInt* bucket_sum_z [[ buffer(2) ]],
- device BigInt* g_points_x [[ buffer(3) ]],
- device BigInt* g_points_y [[ buffer(4) ]],
- device BigInt* g_points_z [[ buffer(5) ]],
- constant uint3& params [[ buffer(6) ]],
- constant uint& workgroup_size [[ buffer(7) ]],
- uint tid [[ thread_position_in_grid ]])
+ device BigInt* bucket_sum_x [[buffer(0), access(read_write)]],
+ device BigInt* bucket_sum_y [[buffer(1), access(read_write)]],
+ device BigInt* bucket_sum_z [[buffer(2), access(read_write)]],
+ device BigInt* g_points_x [[buffer(3), access(write)]],
+ device BigInt* g_points_y [[buffer(4), access(write)]],
+ device BigInt* g_points_z [[buffer(5), access(write)]],
+ constant uint3& params [[buffer(6), access(read)]],
+ uint tid [[thread_position_in_grid]],
+ uint workgroup_size [[dispatch_threads_per_threadgroup]])
{
const uint thread_id = tid;
const uint num_threads_per_subtask = workgroup_size;
const uint subtask_idx = params[0];
const uint num_columns = params[1];
- const uint num_subtasks_per_bpr = params[2]; // Number of subtasks per shader invocation
+ const uint num_subtasks_per_bpr = params[2]; // Number of subtasks per shader invocation
const uint num_buckets_per_subtask = num_columns / 2u;
const uint total_buckets = num_buckets_per_subtask * num_subtasks_per_bpr;
@@ -61,7 +62,9 @@ kernel void bpr_stage_1(
idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * buckets_per_thread + offset;
}
// guard bucket bounds
- if (idx >= total_buckets) { return; }
+ if (idx >= total_buckets) {
+ return;
+ }
Jacobian m = {
.x = bucket_sum_x[idx],
@@ -71,8 +74,6 @@ kernel void bpr_stage_1(
Jacobian g = m;
for (uint i = 0; i < buckets_per_thread - 1u; i++) {
- LOG_DEBUG("[bpr_stage_1] i: %u, thread_id: %u, num_threads_per_subtask: %u, buckets_per_thread: %u, offset: %u", i, thread_id, num_threads_per_subtask, buckets_per_thread, offset);
-
uint idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * buckets_per_thread - 1u - i;
uint bi = offset + idx;
Jacobian b = {
@@ -89,28 +90,21 @@ kernel void bpr_stage_1(
bucket_sum_z[idx] = m.z;
uint g_rw_idx = (subtask_idx / num_subtasks_per_bpr) * (num_threads_per_subtask * num_subtasks_per_bpr) + thread_id;
- // guard write into g_points buffers
- if (g_rw_idx < num_threads_per_subtask * num_subtasks_per_bpr) {
- g_points_x[g_rw_idx] = g.x;
- g_points_y[g_rw_idx] = g.y;
- g_points_z[g_rw_idx] = g.z;
- }
g_points_x[g_rw_idx] = g.x;
g_points_y[g_rw_idx] = g.y;
g_points_z[g_rw_idx] = g.z;
}
-
kernel void bpr_stage_2(
- device BigInt* bucket_sum_x [[ buffer(0) ]],
- device BigInt* bucket_sum_y [[ buffer(1) ]],
- device BigInt* bucket_sum_z [[ buffer(2) ]],
- device BigInt* g_points_x [[ buffer(3) ]],
- device BigInt* g_points_y [[ buffer(4) ]],
- device BigInt* g_points_z [[ buffer(5) ]],
- constant uint3& params [[ buffer(6) ]],
- constant uint& workgroup_size [[ buffer(7) ]],
- uint tid [[ thread_position_in_grid ]])
+ device BigInt* bucket_sum_x [[buffer(0), access(read)]],
+ device BigInt* bucket_sum_y [[buffer(1), access(read)]],
+ device BigInt* bucket_sum_z [[buffer(2), access(read)]],
+ device BigInt* g_points_x [[buffer(3), access(read_write)]],
+ device BigInt* g_points_y [[buffer(4), access(read_write)]],
+ device BigInt* g_points_z [[buffer(5), access(read_write)]],
+ constant uint3& params [[buffer(6), access(read)]],
+ uint tid [[thread_position_in_grid]],
+ uint workgroup_size [[dispatch_threads_per_threadgroup]])
{
const uint thread_id = tid;
const uint num_threads_per_subtask = workgroup_size;
@@ -136,7 +130,7 @@ kernel void bpr_stage_2(
.y = bucket_sum_y[idx],
.z = bucket_sum_z[idx]
};
-
+
const uint g_rw_idx = (subtask_idx / num_subtasks_per_bpr) * (num_threads_per_subtask * num_subtasks_per_bpr) + thread_id;
Jacobian g = {
.x = g_points_x[g_rw_idx],
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal
index c557208b..6bc4d0e3 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/smvp.metal
@@ -1,40 +1,38 @@
-#include
-#include "barrett_reduction.metal"
#include "../curve/jacobian.metal"
+#include "barrett_reduction.metal"
+#include
using namespace metal;
#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
- #include
- constant os_log smvp_logger_kernel(/*subsystem=*/"smvp", /*category=*/"metal");
- #define LOG_DEBUG(...) smvp_logger_kernel.log_debug(__VA_ARGS__)
+#include
+constant os_log smvp_logger_kernel(/*subsystem=*/"smvp", /*category=*/"metal");
+#define LOG_DEBUG(...) smvp_logger_kernel.log_debug(__VA_ARGS__)
#else
- #define LOG_DEBUG(...) ((void)0)
+#define LOG_DEBUG(...) ((void)0)
#endif
kernel void smvp(
- device const uint* row_ptr [[ buffer(0) ]],
- device const uint* val_idx [[ buffer(1) ]],
- device const BigInt* new_point_x [[ buffer(2) ]],
- device const BigInt* new_point_y [[ buffer(3) ]],
- device BigInt* bucket_x [[ buffer(4) ]],
- device BigInt* bucket_y [[ buffer(5) ]],
- device BigInt* bucket_z [[ buffer(6) ]],
- constant uint4& params [[ buffer(7) ]],
- uint3 gid [[thread_position_in_grid]],
- uint3 tid [[thread_position_in_threadgroup]]
-) {
- const uint input_size = params[0];
- const uint num_y_workgroups = params[1];
- const uint num_z_workgroups = params[2];
- const uint subtask_offset = params[3];
+ device const uint* row_ptr [[buffer(0), access(read)]],
+ device const uint* val_idx [[buffer(1), access(read)]],
+ device const BigInt* new_point_x [[buffer(2), access(read)]],
+ device const BigInt* new_point_y [[buffer(3), access(read)]],
+ device BigInt* bucket_x [[buffer(4), access(read_write)]],
+ device BigInt* bucket_y [[buffer(5), access(read_write)]],
+ device BigInt* bucket_z [[buffer(6), access(read_write)]],
+ constant uint4& params [[buffer(7), access(read)]],
+ uint3 tgid [[threadgroup_position_in_grid]],
+ uint3 tid [[thread_position_in_threadgroup]],
+ uint3 workgroup_size [[dispatch_threads_per_threadgroup]],
+ uint3 threadgroup_size [[threadgroups_per_grid]])
+{
+ const uint group_id = (tgid.x * threadgroup_size.y + tgid.y) * threadgroup_size.z + tgid.z;
+ const uint id = group_id * workgroup_size.x + tid.x;
- const uint gidx = gid.x;
- const uint gidy = gid.y;
- const uint gidz = gid.z;
+ const uint input_size = params[0];
+ const uint num_columns = params[1];
+ const uint num_subtasks = params[2];
+ const uint subtask_offset = params[3];
- const uint id = (gidx * num_y_workgroups + gidy) * num_z_workgroups + gidz;
-
- const uint num_columns = NUM_COLUMNS;
const uint half_columns = num_columns / 2;
const uint subtask_idx = id / half_columns;
@@ -59,8 +57,11 @@ kernel void smvp(
// Accumulate all the points for that bucket
Jacobian sum = inf;
+
for (uint k = row_begin; k < row_end; k++) {
- const uint idx = val_idx[ (subtask_idx + subtask_offset) * input_size + k ];
+ const uint val_idx_offset = (subtask_idx + subtask_offset) * input_size + k;
+ const uint idx = val_idx[val_idx_offset];
+
Jacobian b = {
.x = new_point_x[idx],
.y = new_point_y[idx],
@@ -84,6 +85,9 @@ kernel void smvp(
// Store the result in bucket arrays only if bucket_idx > 0
// The final 1D index for the bucket is id + subtask_offset * half_columns
const uint bi = id + subtask_offset * half_columns;
+
+ // Add bounds checking for bucket array access
+ const uint bucket_size = half_columns * NUM_LIMBS * num_subtasks;
if (bucket_idx > 0) {
// If j == 1, add to the existing bucket at index `bi`.
if (j == 1) {
diff --git a/mopro-msm/src/msm/metal_msm/shader/cuzk/transpose.metal b/mopro-msm/src/msm/metal_msm/shader/cuzk/transpose.metal
index 3b8078b1..b6edf74d 100644
--- a/mopro-msm/src/msm/metal_msm/shader/cuzk/transpose.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/cuzk/transpose.metal
@@ -6,45 +6,42 @@
using namespace metal;
kernel void transpose(
- device const uint* all_csr_col_idx [[buffer(0)]],
- device atomic_uint* all_csc_col_ptr [[buffer(1)]],
- device uint* all_csc_val_idxs [[buffer(2)]],
- device uint* all_curr [[buffer(3)]],
- constant uint2& params [[buffer(4)]],
- uint gid [[thread_position_in_grid]]
-) {
+ device const uint* all_csr_col_idx [[buffer(0), access(read)]],
+ device atomic_uint* all_csc_col_ptr [[buffer(1), access(read_write)]],
+ device uint* all_csc_val_idxs [[buffer(2), access(read_write)]],
+ device uint* all_curr [[buffer(3), access(read_write)]],
+ constant uint2& params [[buffer(4), access(read)]],
+ uint gid [[thread_position_in_grid]])
+{
const uint subtask_idx = gid;
- const uint n = params.x; // Number of columns
- const uint input_size = params.y; // Input size
-
+ const uint n = params.x; // Number of columns
+ const uint input_size = params.y; // Input size
+
// Calculate buffer offsets for this subtask
const uint ccp_offset = subtask_idx * (n + 1u);
const uint cci_offset = subtask_idx * input_size;
const uint curr_offset = subtask_idx * n;
-
+
// Phase 1: Count non-zero elements in each column
for (uint j = 0u; j < input_size; j++) {
atomic_fetch_add_explicit(
- &all_csc_col_ptr[ccp_offset + all_csr_col_idx[cci_offset + j] + 1u],
- 1u,
- memory_order_relaxed
- );
+ &all_csc_col_ptr[ccp_offset + all_csr_col_idx[cci_offset + j] + 1u],
+ 1u,
+ memory_order_relaxed);
}
-
+
// Phase 2: Prefix sum for column pointers
for (uint i = 1u; i < n + 1u; i++) {
const uint incremental_sum = atomic_load_explicit(
- &all_csc_col_ptr[ccp_offset + i - 1u],
- memory_order_relaxed
- );
+ &all_csc_col_ptr[ccp_offset + i - 1u],
+ memory_order_relaxed);
atomic_fetch_add_explicit(
- &all_csc_col_ptr[ccp_offset + i],
- incremental_sum,
- memory_order_relaxed
- );
+ &all_csc_col_ptr[ccp_offset + i],
+ incremental_sum,
+ memory_order_relaxed);
}
-
+
// Phase 3: Rearrange elements into CSC format
/// "Traverse the nonzero elements again and move them to their final
/// positions determined by the column offsets in csc_col_ptr and their
@@ -52,16 +49,15 @@ kernel void transpose(
uint val = 0u;
for (uint j = 0; j < input_size; j++) {
const uint col = all_csr_col_idx[cci_offset + j];
-
+
// Get current position for this column
- uint loc = atomic_load_explicit(
- &all_csc_col_ptr[ccp_offset + col],
- memory_order_relaxed
- );
+ const uint loc = atomic_load_explicit(
+ &all_csc_col_ptr[ccp_offset + col],
+ memory_order_relaxed)
+ + all_curr[curr_offset + col];
- loc += all_curr[curr_offset + col];
all_curr[curr_offset + col]++;
-
+
// Store the value index in CSC format
all_csc_val_idxs[cci_offset + loc] = val;
val++;
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal
index 43393364..c4c5ac7f 100644
--- a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal
@@ -2,28 +2,31 @@
#pragma once
using namespace metal;
-#include
-#include
#include "../bigint/bigint.metal"
+#include
+#include
-inline BigInt ff_reduce(BigInt a) {
+inline BigInt ff_reduce(BigInt a)
+{
BigInt p = MODULUS;
BigIntResult res = bigint_sub(a, p);
- if (res.carry == 1) return a;
+ if (res.carry == 1)
+ return a;
return res.value;
}
-inline BigInt ff_add(BigInt a, BigInt b) {
+inline BigInt ff_add(BigInt a, BigInt b)
+{
return ff_reduce(bigint_add_unsafe(a, b).value);
}
-inline BigInt ff_sub(BigInt a, BigInt b) {
+inline BigInt ff_sub(BigInt a, BigInt b)
+{
bool a_gte_b = bigint_gte(a, b);
if (a_gte_b) {
return bigint_sub(a, b).value;
- }
- else {
+ } else {
// p - (b - a)
BigInt p = MODULUS;
BigIntResult diff = bigint_sub(b, a);
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
index 4e4e7ac2..583e06e6 100644
--- a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
@@ -1,15 +1,15 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "ff.metal"
+#include
+#include
kernel void test_ff_add(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* b [[ buffer(1) ]],
- device BigInt* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* b [[buffer(1)]],
+ device BigInt* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
*res = ff_add(*a, *b);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal
index 7db86a2c..d82db0b0 100644
--- a/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_reduce.metal
@@ -1,14 +1,14 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "ff.metal"
+#include
+#include
kernel void test_ff_reduce(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* res [[ buffer(1) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* res [[buffer(1)]],
+ uint gid [[thread_position_in_grid]])
+{
*res = ff_reduce(*a);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal
index a51d9bdb..7a42a603 100644
--- a/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal
@@ -1,15 +1,15 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "ff.metal"
+#include
+#include
kernel void test_ff_sub(
- device BigInt* a [[ buffer(0) ]],
- device BigInt* b [[ buffer(1) ]],
- device BigInt* res [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* a [[buffer(0)]],
+ device BigInt* b [[buffer(1)]],
+ device BigInt* res [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
*res = ff_sub(*a, *b);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/misc/get_constant.metal b/mopro-msm/src/msm/metal_msm/shader/misc/get_constant.metal
index 84cc7b0e..98d312e8 100644
--- a/mopro-msm/src/msm/metal_msm/shader/misc/get_constant.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/misc/get_constant.metal
@@ -1,26 +1,26 @@
#pragma once
using namespace metal;
-#include
#include "../misc/types.metal"
+#include
#define INIT_LIMB(i) BN254_BASEFIELD_MODULUS[i]
// for 16-bit mont_mul_cios (optimised for BN254)
-#define LIMBS_16 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
- INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
+#define LIMBS_16 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
+ INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
INIT_LIMB(8), INIT_LIMB(9), INIT_LIMB(10), INIT_LIMB(11), \
INIT_LIMB(12), INIT_LIMB(13), INIT_LIMB(14), INIT_LIMB(15)
// for 15-bit mont_mul_modified
-#define LIMBS_17 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
- INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
- INIT_LIMB(8), INIT_LIMB(9), INIT_LIMB(10), INIT_LIMB(11), \
+#define LIMBS_17 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
+ INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
+ INIT_LIMB(8), INIT_LIMB(9), INIT_LIMB(10), INIT_LIMB(11), \
INIT_LIMB(12), INIT_LIMB(13), INIT_LIMB(14), INIT_LIMB(15), \
INIT_LIMB(16)
// for 13-bit mont_mul_optimised
-#define LIMBS_20 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
- INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
- INIT_LIMB(8), INIT_LIMB(9), INIT_LIMB(10), INIT_LIMB(11), \
+#define LIMBS_20 INIT_LIMB(0), INIT_LIMB(1), INIT_LIMB(2), INIT_LIMB(3), \
+ INIT_LIMB(4), INIT_LIMB(5), INIT_LIMB(6), INIT_LIMB(7), \
+ INIT_LIMB(8), INIT_LIMB(9), INIT_LIMB(10), INIT_LIMB(11), \
INIT_LIMB(12), INIT_LIMB(13), INIT_LIMB(14), INIT_LIMB(15), \
INIT_LIMB(16), INIT_LIMB(17), INIT_LIMB(18), INIT_LIMB(19)
@@ -35,17 +35,23 @@ constant BigInt MODULUS = {
#endif
};
-BigInt get_mu() {
+BigInt get_mu()
+{
BigInt mu;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
mu.limbs[i] = BARRETT_MU[i];
}
return mu;
}
-BigInt get_n0() {
+BigInt get_n0()
+{
BigInt n0;
uint tmp = N0;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
n0.limbs[i] = tmp & MASK;
tmp >>= LOG_LIMB_SIZE;
@@ -53,56 +59,77 @@ BigInt get_n0() {
return n0;
}
-BigIntWide get_r() {
+BigIntWide get_r()
+{
BigIntWide r;
+
+#pragma unroll(17)
for (uint i = 0; i < NUM_LIMBS_WIDE; i++) {
r.limbs[i] = MONT_RADIX[i];
}
return r;
}
-BigIntWide get_p_wide() {
+BigIntWide get_p_wide()
+{
BigIntWide p;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
p.limbs[i] = BN254_BASEFIELD_MODULUS[i];
}
return p;
}
-BigIntExtraWide get_p_extra_wide() {
+BigIntExtraWide get_p_extra_wide()
+{
BigIntExtraWide p;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
p.limbs[i] = BN254_BASEFIELD_MODULUS[i];
}
return p;
}
-BigInt bigint_zero() {
+BigInt bigint_zero()
+{
BigInt s;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
s.limbs[i] = 0;
}
return s;
}
-BigIntWide bigint_zero_wide() {
+BigIntWide bigint_zero_wide()
+{
BigIntWide s;
+
+#pragma unroll(17)
for (uint i = 0; i < NUM_LIMBS_WIDE; i++) {
s.limbs[i] = 0;
}
return s;
}
-BigIntExtraWide bigint_zero_extra_wide() {
+BigIntExtraWide bigint_zero_extra_wide()
+{
BigIntExtraWide s;
+
+#pragma unroll(32)
for (uint i = 0; i < NUM_LIMBS_EXTRA_WIDE; i++) {
s.limbs[i] = 0;
}
return s;
}
-Jacobian get_bn254_zero() {
+Jacobian get_bn254_zero()
+{
Jacobian zero;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
zero.x.limbs[i] = BN254_ZERO_X[i];
zero.y.limbs[i] = BN254_ZERO_Y[i];
@@ -111,8 +138,11 @@ Jacobian get_bn254_zero() {
return zero;
}
-Jacobian get_bn254_one() {
+Jacobian get_bn254_one()
+{
Jacobian one;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
one.x.limbs[i] = BN254_ONE_X[i];
one.y.limbs[i] = BN254_ONE_Y[i];
@@ -121,8 +151,11 @@ Jacobian get_bn254_one() {
return one;
}
-Jacobian get_bn254_zero_mont() {
+Jacobian get_bn254_zero_mont()
+{
Jacobian zero;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
zero.x.limbs[i] = BN254_ZERO_XR[i];
zero.y.limbs[i] = BN254_ZERO_YR[i];
@@ -131,8 +164,11 @@ Jacobian get_bn254_zero_mont() {
return zero;
}
-Jacobian get_bn254_one_mont() {
+Jacobian get_bn254_one_mont()
+{
Jacobian one;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
one.x.limbs[i] = BN254_ONE_XR[i];
one.y.limbs[i] = BN254_ONE_YR[i];
diff --git a/mopro-msm/src/msm/metal_msm/shader/misc/test_get_constant.metal b/mopro-msm/src/msm/metal_msm/shader/misc/test_get_constant.metal
index b89b2310..43bfe6f9 100644
--- a/mopro-msm/src/msm/metal_msm/shader/misc/test_get_constant.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/misc/test_get_constant.metal
@@ -1,49 +1,58 @@
using namespace metal;
-#include
#include "get_constant.metal"
+#include
-kernel void test_get_mu(device BigInt* result) {
+kernel void test_get_mu(device BigInt* result)
+{
*result = get_mu();
}
-kernel void test_get_n0(device BigInt* result) {
+kernel void test_get_n0(device BigInt* result)
+{
*result = get_n0();
}
-kernel void test_get_r(device BigIntWide* result) {
+kernel void test_get_r(device BigIntWide* result)
+{
*result = get_r();
}
-kernel void test_get_p(device BigInt* result) {
+kernel void test_get_p(device BigInt* result)
+{
*result = MODULUS;
}
-kernel void test_get_p_wide(device BigIntWide* result) {
+kernel void test_get_p_wide(device BigIntWide* result)
+{
*result = get_p_wide();
}
-kernel void test_get_bn254_zero(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z) {
+kernel void test_get_bn254_zero(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z)
+{
Jacobian result = get_bn254_zero();
*result_x = result.x;
*result_y = result.y;
*result_z = result.z;
}
-kernel void test_get_bn254_one(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z) {
+kernel void test_get_bn254_one(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z)
+{
Jacobian result = get_bn254_one();
*result_x = result.x;
*result_y = result.y;
*result_z = result.z;
}
-kernel void test_get_bn254_zero_mont(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z) {
+kernel void test_get_bn254_zero_mont(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z)
+{
Jacobian result = get_bn254_zero_mont();
*result_x = result.x;
*result_y = result.y;
*result_z = result.z;
}
-kernel void test_get_bn254_one_mont(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z) {
+kernel void test_get_bn254_one_mont(device BigInt* result_x, device BigInt* result_y, device BigInt* result_z)
+{
Jacobian result = get_bn254_one_mont();
*result_x = result.x;
*result_y = result.y;
diff --git a/mopro-msm/src/msm/metal_msm/shader/misc/types.metal b/mopro-msm/src/msm/metal_msm/shader/misc/types.metal
index dd3d8faf..04a7914a 100644
--- a/mopro-msm/src/msm/metal_msm/shader/misc/types.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/misc/types.metal
@@ -1,8 +1,8 @@
#pragma once
using namespace metal;
-#include
#include "../constants.metal"
+#include
struct BigInt {
array limbs;
diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal
index ca26a768..1de3d795 100644
--- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal
@@ -4,42 +4,47 @@
#pragma clang diagnostic ignored "-Wdivision-by-zero" // to avoid warning on debug build, but we should always know what NSAFE is
using namespace metal;
-#include
-#include
#include "../field/ff.metal"
+#include
+#include
BigInt conditional_reduce(
BigInt x,
- BigInt y
-) {
+ BigInt y)
+{
if (x >= y) {
return x - y;
}
-
return x;
}
/// An optimised variant of the Montgomery product algorithm from
/// https://github.com/mitschabaude/montgomery#13-x-30-bit-multiplication.
/// Known to work with 12 and 13-bit limbs.
-BigInt mont_mul_optimised(BigInt x, BigInt y) {
+BigInt mont_mul_optimised(BigInt x, BigInt y)
+{
BigInt p = MODULUS;
BigInt s = bigint_zero();
- for (uint i = 0; i < NUM_LIMBS; i ++) {
+
+#pragma unroll(16)
+ for (uint i = 0; i < NUM_LIMBS; i++) {
uint t = s.limbs[0] + x.limbs[i] * y.limbs[0];
uint tprime = t & MASK;
uint qi = (N0 * tprime) & MASK;
uint c = (t + qi * p.limbs[0]) >> LOG_LIMB_SIZE;
s.limbs[0] = s.limbs[1] + x.limbs[i] * y.limbs[1] + qi * p.limbs[1] + c;
- for (uint j = 2; j < NUM_LIMBS; j ++) {
+#pragma unroll(14)
+ for (uint j = 2; j < NUM_LIMBS; j++) {
s.limbs[j - 1] = s.limbs[j] + x.limbs[i] * y.limbs[j] + qi * p.limbs[j];
}
s.limbs[NUM_LIMBS - 2] = x.limbs[i] * y.limbs[NUM_LIMBS - 1] + qi * p.limbs[NUM_LIMBS - 1];
}
uint c = 0;
- for (uint i = 0; i < NUM_LIMBS; i ++) {
+
+#pragma unroll(16)
+ for (uint i = 0; i < NUM_LIMBS; i++) {
uint v = s.limbs[i] + c;
c = v >> LOG_LIMB_SIZE;
s.limbs[i] = v & MASK;
@@ -51,17 +56,20 @@ BigInt mont_mul_optimised(BigInt x, BigInt y) {
/// An modified variant of the Montgomery product algorithm from
/// https://github.com/mitschabaude/montgomery#13-x-30-bit-multiplication.
/// Known to work with 14 and 15-bit limbs.
-BigInt mont_mul_modified(BigInt x, BigInt y) {
+BigInt mont_mul_modified(BigInt x, BigInt y)
+{
BigInt p = MODULUS;
BigInt s = bigint_zero();
- for (uint i = 0; i < NUM_LIMBS; i ++) {
+#pragma unroll(16)
+ for (uint i = 0; i < NUM_LIMBS; i++) {
uint t = s.limbs[0] + x.limbs[i] * y.limbs[0];
uint tprime = t & MASK;
uint qi = (N0 * tprime) & MASK;
uint c = (t + qi * p.limbs[0]) >> LOG_LIMB_SIZE;
- for (uint j = 1; j < NUM_LIMBS - 1; j ++) {
+#pragma unroll(14)
+ for (uint j = 1; j < NUM_LIMBS - 1; j++) {
uint t = s.limbs[j] + x.limbs[i] * y.limbs[j] + qi * p.limbs[j];
if ((j - 1) % NSAFE == 0) {
t = t + c;
@@ -80,7 +88,9 @@ BigInt mont_mul_modified(BigInt x, BigInt y) {
}
uint c = 0;
- for (uint i = 0; i < NUM_LIMBS; i ++) {
+
+#pragma unroll(16)
+ for (uint i = 0; i < NUM_LIMBS; i++) {
uint v = s.limbs[i] + c;
c = v >> LOG_LIMB_SIZE;
s.limbs[i] = v & MASK;
@@ -92,15 +102,19 @@ BigInt mont_mul_modified(BigInt x, BigInt y) {
/// The CIOS method for Montgomery multiplication from Tolga Acar's thesis:
/// High-Speed Algorithms & Architectures For Number-Theoretic Cryptosystems
/// https://www.proquest.com/openview/1018972f191afe55443658b28041c118/1
-inline BigInt mont_mul_cios(BigInt x, BigInt y) {
+inline BigInt mont_mul_cios(BigInt x, BigInt y)
+{
BigInt p = MODULUS;
BigInt result;
- uint t[NUM_LIMBS + 2] = {0}; // Extra space for carries
-
+ uint t[NUM_LIMBS + 2] = { 0 }; // Extra space for carries
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
// Step 1: Multiply and add
uint c = 0;
+
+#pragma unroll(16)
for (uint j = 0; j < NUM_LIMBS; j++) {
uint r = t[j] + x.limbs[j] * y.limbs[i] + c;
c = r >> LOG_LIMB_SIZE;
@@ -115,6 +129,7 @@ inline BigInt mont_mul_cios(BigInt x, BigInt y) {
r = t[0] + m * p.limbs[0];
c = r >> LOG_LIMB_SIZE;
+#pragma unroll(15)
for (uint j = 1; j < NUM_LIMBS; j++) {
r = t[j] + m * p.limbs[j] + c;
c = r >> LOG_LIMB_SIZE;
@@ -129,6 +144,8 @@ inline BigInt mont_mul_cios(BigInt x, BigInt y) {
// Final reduction check
bool t_lt_p = false;
+
+#pragma unroll(16)
for (uint idx = 0; idx < NUM_LIMBS; idx++) {
uint i = NUM_LIMBS - 1 - idx;
if (t[i] < p.limbs[i]) {
@@ -140,11 +157,14 @@ inline BigInt mont_mul_cios(BigInt x, BigInt y) {
}
if (t_lt_p) {
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
result.limbs[i] = t[i];
}
} else {
uint borrow = 0;
+
+#pragma unroll(16)
for (uint i = 0; i < NUM_LIMBS; i++) {
uint diff = t[i] - p.limbs[i] - borrow;
if (t[i] < (p.limbs[i] + borrow)) {
diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal
index 8d111df3..14f2ab86 100644
--- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal
@@ -1,15 +1,15 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "mont.metal"
+#include
+#include
kernel void test_mont_mul_cios(
- device BigInt* lhs [[ buffer(0) ]],
- device BigInt* rhs [[ buffer(1) ]],
- device BigInt* result [[ buffer(2) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* lhs [[buffer(0)]],
+ device BigInt* rhs [[buffer(1)]],
+ device BigInt* result [[buffer(2)]],
+ uint gid [[thread_position_in_grid]])
+{
*result = mont_mul_cios(*lhs, *rhs);
}
diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal
index bd2a01a6..e77a1a3f 100644
--- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal
@@ -1,21 +1,21 @@
using namespace metal;
-#include
-#include
#include "mont.metal"
+#include
+#include
kernel void test_mont_mul_cios_benchmarks(
- device BigInt* lhs [[ buffer(0) ]],
- device BigInt* rhs [[ buffer(1) ]],
- device array* cost [[ buffer(2) ]],
- device BigInt* result [[ buffer(3) ]],
- uint gid [[ thread_position_in_grid ]]
-) {
+ device BigInt* lhs [[buffer(0)]],
+ device BigInt* rhs [[buffer(1)]],
+ device array* cost [[buffer(2)]],
+ device BigInt* result [[buffer(3)]],
+ uint gid [[thread_position_in_grid]])
+{
BigInt a = *lhs;
BigInt b = *rhs;
array cost_arr = *cost;
BigInt c = mont_mul_cios(a, a);
- for (uint i = 1; i < cost_arr[0]; i ++) {
+ for (uint i = 1; i < cost_arr[0]; i++) {
c = mont_mul_cios(c, a);
}
*result = mont_mul_cios(c, b);
diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal
index 0be881f9..7887f7a7 100644
--- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal
+++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal
@@ -1,15 +1,15 @@
// source: https://github.com/geometryxyz/msl-secp256k1
using namespace metal;
-#include
-#include
#include "mont.metal"
+#include