From 0f14fd6b5c64652b75c31a61768e446e3e7e3f83 Mon Sep 17 00:00:00 2001 From: moven0831 Date: Mon, 2 Jun 2025 12:39:14 +0800 Subject: [PATCH 1/2] refactor(metal_msm): unify buffer creation methods --- mopro-msm/.gitignore | 3 + mopro-msm/src/msm/metal_msm/metal_msm.rs | 68 +++++++------ .../tests/bigint/bigint_add_unsafe.rs | 9 +- .../metal_msm/tests/bigint/bigint_add_wide.rs | 9 +- .../msm/metal_msm/tests/bigint/bigint_sub.rs | 9 +- .../tests/curve/jacobian_add_2007_b1.rs | 31 +++--- .../tests/curve/jacobian_dbl_2009_l.rs | 22 +++-- .../tests/curve/jacobian_madd_2007_bl.rs | 28 ++++-- .../msm/metal_msm/tests/curve/jacobian_neg.rs | 22 +++-- .../tests/curve/jacobian_scalar_mul.rs | 9 +- .../metal_msm/tests/cuzk/barrett_reduction.rs | 16 ++- ...vert_point_coords_and_decompose_scalars.rs | 36 ++++--- .../src/msm/metal_msm/tests/cuzk/pbpr.rs | 20 ++-- .../src/msm/metal_msm/tests/cuzk/smvp.rs | 20 ++-- .../src/msm/metal_msm/tests/cuzk/transpose.rs | 99 +++++++++---------- .../src/msm/metal_msm/tests/field/ff_add.rs | 9 +- .../msm/metal_msm/tests/field/ff_reduce.rs | 14 ++- .../src/msm/metal_msm/tests/field/ff_sub.rs | 9 +- .../msm/metal_msm/tests/misc/get_constant.rs | 43 ++++---- .../tests/mont_backend/mont_benchmarks.rs | 13 +-- .../tests/mont_backend/mont_mul_cios.rs | 11 +-- .../tests/mont_backend/mont_mul_modified.rs | 11 +-- .../tests/mont_backend/mont_mul_optimised.rs | 11 +-- .../src/msm/metal_msm/utils/metal_wrapper.rs | 19 ++-- 24 files changed, 262 insertions(+), 279 deletions(-) diff --git a/mopro-msm/.gitignore b/mopro-msm/.gitignore index 2a2b775c..4d30bfb7 100644 --- a/mopro-msm/.gitignore +++ b/mopro-msm/.gitignore @@ -22,3 +22,6 @@ proptest-regressions # Metal shader intermediate files and libraries src/msm/metal_msm/shader/**/*.ir src/msm/metal_msm/shader/**/*.lib + +# Ignore constants.metal +src/msm/metal_msm/shader/constants.metal diff --git a/mopro-msm/src/msm/metal_msm/metal_msm.rs b/mopro-msm/src/msm/metal_msm/metal_msm.rs index b22e1abf..c6fc0503 100644 --- a/mopro-msm/src/msm/metal_msm/metal_msm.rs +++ b/mopro-msm/src/msm/metal_msm/metal_msm.rs @@ -266,15 +266,15 @@ impl ConvertPointAndScalarDecompose { ) -> Result<(Vec, Vec, Vec), Box> { let mut helper = MetalHelper::new(); - let coords_buf = helper.create_input_buffer(&coords.to_vec()); - let scalars_buf = helper.create_input_buffer(&scalars.to_vec()); + let coords_buf = helper.create_buffer(&coords.to_vec()); + let scalars_buf = helper.create_buffer(&scalars.to_vec()); - let out_point_x = helper.create_output_buffer(input_size * self.config.num_limbs); - let out_point_y = helper.create_output_buffer(input_size * self.config.num_limbs); - let out_scalar_chunks = helper.create_output_buffer(input_size * num_subtasks); + let out_point_x = helper.create_empty_buffer(input_size * self.config.num_limbs); + let out_point_y = helper.create_empty_buffer(input_size * self.config.num_limbs); + let out_scalar_chunks = helper.create_empty_buffer(input_size * num_subtasks); - let input_size_buf = helper.create_input_buffer(&vec![input_size as u32]); - let num_y_workgroups_buf = helper.create_input_buffer(&vec![c_num_y_workgroups as u32]); + let input_size_buf = helper.create_buffer(&vec![input_size as u32]); + let num_y_workgroups_buf = helper.create_buffer(&vec![c_num_y_workgroups as u32]); let thread_group_count = helper.create_thread_group_size( c_num_x_workgroups as u64, @@ -286,8 +286,10 @@ impl ConvertPointAndScalarDecompose { helper.execute_shader( &self.config, - &[&coords_buf, &scalars_buf, &input_size_buf], &[ + &coords_buf, + &scalars_buf, + &input_size_buf, &out_point_x, &out_point_y, &out_scalar_chunks, @@ -337,14 +339,14 @@ impl Transpose { ) -> Result<(Vec, Vec), Box> { let mut helper = MetalHelper::new(); - let in_chunks_buf = helper.create_input_buffer(&scalar_chunks.to_vec()); + let in_chunks_buf = helper.create_buffer(&scalar_chunks.to_vec()); let out_csc_col_ptr = - helper.create_output_buffer(num_subtasks * ((num_columns + 1) as usize) * 4); - let out_csc_val_idxs = helper.create_output_buffer(scalar_chunks.len()); - let out_curr = helper.create_output_buffer(num_subtasks * (num_columns as usize) * 4); + helper.create_empty_buffer(num_subtasks * ((num_columns + 1) as usize) * 4); + let out_csc_val_idxs = helper.create_empty_buffer(scalar_chunks.len()); + let out_curr = helper.create_empty_buffer(num_subtasks * (num_columns as usize) * 4); let params = vec![num_columns, input_size as u32]; - let params_buf = helper.create_input_buffer(¶ms); + let params_buf = helper.create_buffer(¶ms); let thread_group_count = helper.create_thread_group_size( t_num_x_workgroups as u64, @@ -363,7 +365,6 @@ impl Transpose { &out_curr, ¶ms_buf, ], - &[], &thread_group_count, &threads_per_threadgroup, ); @@ -416,14 +417,14 @@ impl SMVP { let bucket_size = (num_columns / 2) as usize * self.config.num_limbs * 4 * num_subtasks; // Create buffers - let row_ptr_buf = helper.create_input_buffer(&csc_col_ptr.to_vec()); - let val_idx_buf = helper.create_input_buffer(&csc_val_idxs.to_vec()); - let point_x_buf = helper.create_input_buffer(&point_x.to_vec()); - let point_y_buf = helper.create_input_buffer(&point_y.to_vec()); + let row_ptr_buf = helper.create_buffer(&csc_col_ptr.to_vec()); + let val_idx_buf = helper.create_buffer(&csc_val_idxs.to_vec()); + let point_x_buf = helper.create_buffer(&point_x.to_vec()); + let point_y_buf = helper.create_buffer(&point_y.to_vec()); - let bucket_x_buf = helper.create_output_buffer(bucket_size); - let bucket_y_buf = helper.create_output_buffer(bucket_size); - let bucket_z_buf = helper.create_output_buffer(bucket_size); + let bucket_x_buf = helper.create_empty_buffer(bucket_size); + let bucket_y_buf = helper.create_empty_buffer(bucket_size); + let bucket_z_buf = helper.create_empty_buffer(bucket_size); // Execute in chunks let num_subtask_chunk_size = 4u32; @@ -434,7 +435,7 @@ impl SMVP { num_subtasks as u32, offset, ]; - let params_buf = helper.create_input_buffer(¶ms); + let params_buf = helper.create_buffer(¶ms); let adjusted_x_workgroups = s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size as usize); @@ -459,7 +460,6 @@ impl SMVP { &bucket_z_buf, ¶ms_buf, ], - &[], &thread_group_count, &threads_per_threadgroup, ); @@ -518,14 +518,14 @@ impl PBPR { ) -> Result<(Vec, Vec, Vec), Box> { let mut helper = MetalHelper::new(); - let bucket_sum_x_buf = helper.create_input_buffer(&bucket_x.to_vec()); - let bucket_sum_y_buf = helper.create_input_buffer(&bucket_y.to_vec()); - let bucket_sum_z_buf = helper.create_input_buffer(&bucket_z.to_vec()); + let bucket_sum_x_buf = helper.create_buffer(&bucket_x.to_vec()); + let bucket_sum_y_buf = helper.create_buffer(&bucket_y.to_vec()); + let bucket_sum_z_buf = helper.create_buffer(&bucket_z.to_vec()); let g_points_size = num_subtasks * b_workgroup_size * self.stage1_config.num_limbs * 4; - let g_points_x_buf = helper.create_output_buffer(g_points_size); - let g_points_y_buf = helper.create_output_buffer(g_points_size); - let g_points_z_buf = helper.create_output_buffer(g_points_size); + let g_points_x_buf = helper.create_empty_buffer(g_points_size); + let g_points_y_buf = helper.create_empty_buffer(g_points_size); + let g_points_z_buf = helper.create_empty_buffer(g_points_size); // Stage 1 for subtask_chunk_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_1) { @@ -534,8 +534,8 @@ impl PBPR { num_columns, num_subtasks_per_bpr_1 as u32, ]; - let params_buf = helper.create_input_buffer(¶ms); - let workgroup_size_buf = helper.create_input_buffer(&vec![b_workgroup_size as u32]); + let params_buf = helper.create_buffer(¶ms); + let workgroup_size_buf = helper.create_buffer(&vec![b_workgroup_size as u32]); let stage1_thread_group_count = helper.create_thread_group_size( b_num_x_workgroups as u64, @@ -557,7 +557,6 @@ impl PBPR { ¶ms_buf, &workgroup_size_buf, ], - &[], &stage1_thread_group_count, &stage1_threads_per_threadgroup, ); @@ -570,8 +569,8 @@ impl PBPR { num_columns, num_subtasks_per_bpr_2 as u32, ]; - let params_buf = helper.create_input_buffer(¶ms); - let workgroup_size_buf = helper.create_input_buffer(&vec![b_workgroup_size as u32]); + let params_buf = helper.create_buffer(¶ms); + let workgroup_size_buf = helper.create_buffer(&vec![b_workgroup_size as u32]); let stage2_thread_group_count = helper.create_thread_group_size( b_2_num_x_workgroups as u64, @@ -593,7 +592,6 @@ impl PBPR { ¶ms_buf, &workgroup_size_buf, ], - &[], &stage2_thread_group_count, &stage2_threads_per_threadgroup, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs index 283847f5..758354b8 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs @@ -29,17 +29,16 @@ pub fn test_bigint_add_unsafe() { } }; - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let b_buf = helper.create_input_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let b_buf = helper.create_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_size = helper.create_thread_group_size(1, 1, 1); let thread_group_count = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs index f7bff21c..018f4ac7 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs @@ -43,17 +43,16 @@ fn run_bigint_add_test(a: &BigInt<4>, b: &BigInt<4>, expected: &BigInt<4>) { let mut helper = MetalHelper::new(); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let b_buf = helper.create_input_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs + 1); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let b_buf = helper.create_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs + 1); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs index 8da0c6ee..753f3e7c 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs @@ -28,17 +28,16 @@ fn run_bigint_sub_test(a: BigInt<4>, b: BigInt<4>, expected: BigInt<4>) { let mut helper = MetalHelper::new(); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let b_buf = helper.create_input_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let b_buf = helper.create_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs index 015c1305..faee7d31 100644 --- a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs +++ b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs @@ -61,15 +61,15 @@ fn jacobian_add_2007_bl_kernel(a: G, b: G, shader_name: &str) -> G { .to_limbs(num_limbs, log_limb_size); // Create buffers - let axr_buf = helper.create_input_buffer(&axr_limbs); - let ayr_buf = helper.create_input_buffer(&ayr_limbs); - let azr_buf = helper.create_input_buffer(&azr_limbs); - let bxr_buf = helper.create_input_buffer(&bxr_limbs); - let byr_buf = helper.create_input_buffer(&byr_limbs); - let bzr_buf = helper.create_input_buffer(&bzr_limbs); - let result_xr_buf = helper.create_output_buffer(num_limbs); - let result_yr_buf = helper.create_output_buffer(num_limbs); - let result_zr_buf = helper.create_output_buffer(num_limbs); + let axr_buf = helper.create_buffer(&axr_limbs); + let ayr_buf = helper.create_buffer(&ayr_limbs); + let azr_buf = helper.create_buffer(&azr_limbs); + let bxr_buf = helper.create_buffer(&bxr_limbs); + let byr_buf = helper.create_buffer(&byr_limbs); + let bzr_buf = helper.create_buffer(&bzr_limbs); + let result_xr_buf = helper.create_empty_buffer(num_limbs); + let result_yr_buf = helper.create_empty_buffer(num_limbs); + let result_zr_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -77,8 +77,17 @@ fn jacobian_add_2007_bl_kernel(a: G, b: G, shader_name: &str) -> G { helper.execute_shader( &config, - &[&axr_buf, &ayr_buf, &azr_buf, &bxr_buf, &byr_buf, &bzr_buf], - &[&result_xr_buf, &result_yr_buf, &result_zr_buf], + &[ + &axr_buf, + &ayr_buf, + &azr_buf, + &bxr_buf, + &byr_buf, + &bzr_buf, + &result_xr_buf, + &result_yr_buf, + &result_zr_buf, + ], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_dbl_2009_l.rs b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_dbl_2009_l.rs index c10c0521..4f49700e 100644 --- a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_dbl_2009_l.rs +++ b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_dbl_2009_l.rs @@ -52,12 +52,12 @@ pub fn test_jacobian_dbl_2009_l() { .to_limbs(num_limbs, log_limb_size); // Create buffers - let axr_buf = helper.create_input_buffer(&axr_limbs); - let ayr_buf = helper.create_input_buffer(&ayr_limbs); - let azr_buf = helper.create_input_buffer(&azr_limbs); - let result_xr_buf = helper.create_output_buffer(num_limbs); - let result_yr_buf = helper.create_output_buffer(num_limbs); - let result_zr_buf = helper.create_output_buffer(num_limbs); + let axr_buf = helper.create_buffer(&axr_limbs); + let ayr_buf = helper.create_buffer(&ayr_limbs); + let azr_buf = helper.create_buffer(&azr_limbs); + let result_xr_buf = helper.create_empty_buffer(num_limbs); + let result_yr_buf = helper.create_empty_buffer(num_limbs); + let result_zr_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -65,8 +65,14 @@ pub fn test_jacobian_dbl_2009_l() { helper.execute_shader( &config, - &[&axr_buf, &ayr_buf, &azr_buf], - &[&result_xr_buf, &result_yr_buf, &result_zr_buf], + &[ + &axr_buf, + &ayr_buf, + &azr_buf, + &result_xr_buf, + &result_yr_buf, + &result_zr_buf, + ], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs index ab9b7622..7c929277 100644 --- a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs +++ b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs @@ -78,14 +78,14 @@ pub fn test_jacobian_madd_2007_bl() { .to_limbs(num_limbs, log_limb_size); // Create buffers - let axr_buf = helper.create_input_buffer(&axr_limbs); - let ayr_buf = helper.create_input_buffer(&ayr_limbs); - let azr_buf = helper.create_input_buffer(&azr_limbs); - let bxr_buf = helper.create_input_buffer(&bxr_limbs); - let byr_buf = helper.create_input_buffer(&byr_limbs); - let result_xr_buf = helper.create_output_buffer(num_limbs); - let result_yr_buf = helper.create_output_buffer(num_limbs); - let result_zr_buf = helper.create_output_buffer(num_limbs); + let axr_buf = helper.create_buffer(&axr_limbs); + let ayr_buf = helper.create_buffer(&ayr_limbs); + let azr_buf = helper.create_buffer(&azr_limbs); + let bxr_buf = helper.create_buffer(&bxr_limbs); + let byr_buf = helper.create_buffer(&byr_limbs); + let result_xr_buf = helper.create_empty_buffer(num_limbs); + let result_yr_buf = helper.create_empty_buffer(num_limbs); + let result_zr_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -93,8 +93,16 @@ pub fn test_jacobian_madd_2007_bl() { helper.execute_shader( &config, - &[&axr_buf, &ayr_buf, &azr_buf, &bxr_buf, &byr_buf], - &[&result_xr_buf, &result_yr_buf, &result_zr_buf], + &[ + &axr_buf, + &ayr_buf, + &azr_buf, + &bxr_buf, + &byr_buf, + &result_xr_buf, + &result_yr_buf, + &result_zr_buf, + ], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_neg.rs b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_neg.rs index 7b3e8b28..9e3d35c8 100644 --- a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_neg.rs +++ b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_neg.rs @@ -52,12 +52,12 @@ pub fn test_jacobian_neg() { .to_limbs(num_limbs, log_limb_size); // Create buffers - let axr_buf = helper.create_input_buffer(&axr_limbs); - let ayr_buf = helper.create_input_buffer(&ayr_limbs); - let azr_buf = helper.create_input_buffer(&azr_limbs); - let result_xr_buf = helper.create_output_buffer(num_limbs); - let result_yr_buf = helper.create_output_buffer(num_limbs); - let result_zr_buf = helper.create_output_buffer(num_limbs); + let axr_buf = helper.create_buffer(&axr_limbs); + let ayr_buf = helper.create_buffer(&ayr_limbs); + let azr_buf = helper.create_buffer(&azr_limbs); + let result_xr_buf = helper.create_empty_buffer(num_limbs); + let result_yr_buf = helper.create_empty_buffer(num_limbs); + let result_zr_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -65,8 +65,14 @@ pub fn test_jacobian_neg() { helper.execute_shader( &config, - &[&axr_buf, &ayr_buf, &azr_buf], - &[&result_xr_buf, &result_yr_buf, &result_zr_buf], + &[ + &axr_buf, + &ayr_buf, + &azr_buf, + &result_xr_buf, + &result_yr_buf, + &result_zr_buf, + ], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_scalar_mul.rs b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_scalar_mul.rs index 802a197f..d945bd33 100644 --- a/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_scalar_mul.rs +++ b/mopro-msm/src/msm/metal_msm/tests/curve/jacobian_scalar_mul.rs @@ -44,7 +44,7 @@ fn jacobian_scalar_mul_kernel(point: G, scalar: u32, name: &str) -> G { .to_limbs(num_limbs, log_limb_size); // Create input buffers - let point_buf = helper.create_input_buffer( + let point_buf = helper.create_buffer( &[ axr_limbs.as_slice(), ayr_limbs.as_slice(), @@ -52,8 +52,8 @@ fn jacobian_scalar_mul_kernel(point: G, scalar: u32, name: &str) -> G { ] .concat(), ); - let scalar_buf = helper.create_input_buffer(&vec![scalar]); - let result_buf = helper.create_output_buffer(num_limbs * 3); + let scalar_buf = helper.create_buffer(&vec![scalar]); + let result_buf = helper.create_empty_buffer(num_limbs * 3); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -61,8 +61,7 @@ fn jacobian_scalar_mul_kernel(point: G, scalar: u32, name: &str) -> G { helper.execute_shader( &config, - &[&point_buf, &scalar_buf], - &[&result_buf], + &[&point_buf, &scalar_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/barrett_reduction.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/barrett_reduction.rs index c85cc957..bc8bd77d 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/barrett_reduction.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/barrett_reduction.rs @@ -37,8 +37,8 @@ pub fn test_barrett_reduce_with_mont_params() { let mont_a_limbs = mont_a_in_ark.to_limbs(num_limbs_extra_wide, log_limb_size); // Create buffers - let mont_a_buf = helper.create_input_buffer(&mont_a_limbs); - let result_buf = helper.create_output_buffer(num_limbs); + let mont_a_buf = helper.create_buffer(&mont_a_limbs); + let result_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -46,8 +46,7 @@ pub fn test_barrett_reduce_with_mont_params() { helper.execute_shader( &config, - &[&mont_a_buf], - &[&result_buf], + &[&mont_a_buf, &result_buf], &thread_group_count, &thread_group_size, ); @@ -91,9 +90,9 @@ pub fn test_field_mul_with_mont_params() { let r_limbs = r_in_ark.to_limbs(num_limbs_wide, log_limb_size); // Create buffers - let a_buf = helper.create_input_buffer(&a_limbs); - let r_buf = helper.create_input_buffer(&r_limbs); - let res_buf = helper.create_output_buffer(num_limbs); + let a_buf = helper.create_buffer(&a_limbs); + let r_buf = helper.create_buffer(&r_limbs); + let res_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -101,8 +100,7 @@ pub fn test_field_mul_with_mont_params() { helper.execute_shader( &config, - &[&a_buf, &r_buf], - &[&res_buf], + &[&a_buf, &r_buf, &res_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/convert_point_coords_and_decompose_scalars.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/convert_point_coords_and_decompose_scalars.rs index 685d407f..8dcb866e 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/convert_point_coords_and_decompose_scalars.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/convert_point_coords_and_decompose_scalars.rs @@ -59,15 +59,15 @@ fn test_point_coords_conversion() { let coords = [x_packed, y_packed].concat(); // Setup Metal buffers - let coords_buf = helper.create_input_buffer(&coords); - let scalars_buf = helper.create_input_buffer(&scalars); - let input_size_buf = helper.create_input_buffer(&vec![1u32]); - let num_y_workgroups_buf = helper.create_input_buffer(&vec![1u32]); + let coords_buf = helper.create_buffer(&coords); + let scalars_buf = helper.create_buffer(&scalars); + let input_size_buf = helper.create_buffer(&vec![1u32]); + let num_y_workgroups_buf = helper.create_buffer(&vec![1u32]); // Prepare output buffers for the kernel - let point_x_buf = helper.create_output_buffer(num_limbs); - let point_y_buf = helper.create_output_buffer(num_limbs); - let chunks_buf = helper.create_output_buffer(num_limbs); + let point_x_buf = helper.create_empty_buffer(num_limbs); + let point_y_buf = helper.create_empty_buffer(num_limbs); + let chunks_buf = helper.create_empty_buffer(num_limbs); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -76,8 +76,10 @@ fn test_point_coords_conversion() { // Execute the shader helper.execute_shader( &config, - &[&coords_buf, &scalars_buf, &input_size_buf], &[ + &coords_buf, + &scalars_buf, + &input_size_buf, &point_x_buf, &point_y_buf, &chunks_buf, @@ -133,15 +135,15 @@ fn test_scalar_decomposition() { let packed_scalars = pack_limbs(&scalars); // Setup Metal buffers - let coords_buf = helper.create_input_buffer(&coords); - let scalars_buf = helper.create_input_buffer(&packed_scalars); - let input_size_buf = helper.create_input_buffer(&vec![1u32]); - let num_y_workgroups_buf = helper.create_input_buffer(&vec![1u32]); + let coords_buf = helper.create_buffer(&coords); + let scalars_buf = helper.create_buffer(&packed_scalars); + let input_size_buf = helper.create_buffer(&vec![1u32]); + let num_y_workgroups_buf = helper.create_buffer(&vec![1u32]); // We'll ignore X,Y outputs, but we must pass them - let point_x_buf = helper.create_output_buffer(num_limbs); - let point_y_buf = helper.create_output_buffer(num_limbs); - let chunks_buf = helper.create_output_buffer(num_subtasks); + let point_x_buf = helper.create_empty_buffer(num_limbs); + let point_y_buf = helper.create_empty_buffer(num_limbs); + let chunks_buf = helper.create_empty_buffer(num_subtasks); // Setup thread group sizes let thread_group_count = helper.create_thread_group_size(1, 1, 1); @@ -150,8 +152,10 @@ fn test_scalar_decomposition() { // Execute the shader helper.execute_shader( &config, - &[&coords_buf, &scalars_buf, &input_size_buf], &[ + &coords_buf, + &scalars_buf, + &input_size_buf, &point_x_buf, &point_y_buf, &chunks_buf, diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/pbpr.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/pbpr.rs index b22257f2..0538727f 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/pbpr.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/pbpr.rs @@ -74,16 +74,16 @@ fn test_pbpr_stage1_and_stage2() { //---------------------------------------------- let mut helper = MetalHelper::new(); - let bucket_sum_x_buf = helper.create_input_buffer(&bucket_sum_x_limbs); - let bucket_sum_y_buf = helper.create_input_buffer(&bucket_sum_y_limbs); - let bucket_sum_z_buf = helper.create_input_buffer(&bucket_sum_z_limbs); + let bucket_sum_x_buf = helper.create_buffer(&bucket_sum_x_limbs); + let bucket_sum_y_buf = helper.create_buffer(&bucket_sum_y_limbs); + let bucket_sum_z_buf = helper.create_buffer(&bucket_sum_z_limbs); - let g_points_x_buf = helper.create_input_buffer(&g_points_x_limbs); - let g_points_y_buf = helper.create_input_buffer(&g_points_y_limbs); - let g_points_z_buf = helper.create_input_buffer(&g_points_z_limbs); + let g_points_x_buf = helper.create_buffer(&g_points_x_limbs); + let g_points_y_buf = helper.create_buffer(&g_points_y_limbs); + let g_points_z_buf = helper.create_buffer(&g_points_z_limbs); let wg_size_vec = vec![workgroup_size]; - let wg_size_buf = helper.create_input_buffer(&wg_size_vec); + let wg_size_buf = helper.create_buffer(&wg_size_vec); let thread_group_count = helper.create_thread_group_size(num_subtasks_per_bpr as u64, 1, 1); let thread_group_size = helper.create_thread_group_size(workgroup_size as u64, 1, 1); @@ -110,7 +110,7 @@ fn test_pbpr_stage1_and_stage2() { num_columns, num_subtasks_per_bpr as u32, ]; - let params_buf = helper.create_input_buffer(¶ms); + let params_buf = helper.create_buffer(¶ms); helper.execute_shader( &config_stage1, &[ @@ -123,7 +123,6 @@ fn test_pbpr_stage1_and_stage2() { ¶ms_buf, &wg_size_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -134,7 +133,7 @@ fn test_pbpr_stage1_and_stage2() { num_columns, num_subtasks_per_bpr as u32, ]; - let params_buf = helper.create_input_buffer(¶ms); + let params_buf = helper.create_buffer(¶ms); helper.execute_shader( &config_stage2, &[ @@ -147,7 +146,6 @@ fn test_pbpr_stage1_and_stage2() { ¶ms_buf, &wg_size_buf, ], - &[], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/smvp.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/smvp.rs index a40d4f51..44ca2cb8 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/smvp.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/smvp.rs @@ -49,16 +49,13 @@ fn smvp_gpu( let bucket_sum_coord_bytelength = (num_columns / 2) as usize * config.num_limbs as usize * 4 * num_subtasks as usize; - // Input buffers - let row_ptr_buf = helper.create_input_buffer(gpu_csc_col_ptr); - let val_idx_buf = helper.create_input_buffer(gpu_csc_val_idxs); - let point_x_buf = helper.create_input_buffer(gpu_point_x); - let point_y_buf = helper.create_input_buffer(gpu_point_y); - - // Output "bucket" buffers - let bucket_x_buf = helper.create_output_buffer(bucket_sum_coord_bytelength); - let bucket_y_buf = helper.create_output_buffer(bucket_sum_coord_bytelength); - let bucket_z_buf = helper.create_output_buffer(bucket_sum_coord_bytelength); + let row_ptr_buf = helper.create_buffer(gpu_csc_col_ptr); + let val_idx_buf = helper.create_buffer(gpu_csc_val_idxs); + let point_x_buf = helper.create_buffer(gpu_point_x); + let point_y_buf = helper.create_buffer(gpu_point_y); + let bucket_x_buf = helper.create_empty_buffer(bucket_sum_coord_bytelength); + let bucket_y_buf = helper.create_empty_buffer(bucket_sum_coord_bytelength); + let bucket_z_buf = helper.create_empty_buffer(bucket_sum_coord_bytelength); // Launch shader for each subtask chunk for offset in (0..num_subtasks as u32).step_by(num_subtask_chunk_size as usize) { @@ -69,7 +66,7 @@ fn smvp_gpu( s_num_z_workgroups, offset, ]; - let params_buf = helper.create_input_buffer(¶ms); + let params_buf = helper.create_buffer(¶ms); let adjusted_s_num_x_workgroups = if num_columns < 256 { s_num_x_workgroups @@ -101,7 +98,6 @@ fn smvp_gpu( &bucket_z_buf, ¶ms_buf, ], - &[], &thread_group_count, &threads_per_group, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/cuzk/transpose.rs b/mopro-msm/src/msm/metal_msm/tests/cuzk/transpose.rs index d70c0012..610bb296 100644 --- a/mopro-msm/src/msm/metal_msm/tests/cuzk/transpose.rs +++ b/mopro-msm/src/msm/metal_msm/tests/cuzk/transpose.rs @@ -37,11 +37,11 @@ fn test_sparse_matrix_transposition() { } // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -56,7 +56,6 @@ fn test_sparse_matrix_transposition() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -114,11 +113,11 @@ fn test_transpose_single_element() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -133,7 +132,6 @@ fn test_transpose_single_element() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -175,11 +173,11 @@ fn test_transpose_all_same_column() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -194,7 +192,6 @@ fn test_transpose_all_same_column() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -236,11 +233,11 @@ fn test_transpose_sequential_columns() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -255,7 +252,6 @@ fn test_transpose_sequential_columns() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -297,11 +293,11 @@ fn test_transpose_reverse_order() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -316,7 +312,6 @@ fn test_transpose_reverse_order() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -358,11 +353,11 @@ fn test_transpose_empty_columns() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -377,7 +372,6 @@ fn test_transpose_empty_columns() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -437,11 +431,11 @@ fn test_transpose_multiple_subtasks_edge_cases() { all_csr_col_idx.extend_from_slice(&cols2); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -456,7 +450,6 @@ fn test_transpose_multiple_subtasks_edge_cases() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -513,11 +506,11 @@ fn test_transpose_boundary_columns() { let (expected_col_ptr, expected_val_idxs) = compute_expected_csc(&all_csr_col_idx, n); // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -532,7 +525,6 @@ fn test_transpose_boundary_columns() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); @@ -584,11 +576,11 @@ fn test_transpose_large_scale() { } // Create buffers - let all_csr_col_idx_buf = helper.create_input_buffer(&all_csr_col_idx); - let all_csc_col_ptr_buf = helper.create_output_buffer(num_subtasks * (n as usize + 1)); - let all_csc_val_idxs_buf = helper.create_output_buffer(num_subtasks * input_size as usize); - let all_curr_buf = helper.create_output_buffer(num_subtasks * n as usize); - let params_buf = helper.create_input_buffer(&vec![n, input_size]); + let all_csr_col_idx_buf = helper.create_buffer(&all_csr_col_idx); + let all_csc_col_ptr_buf = helper.create_empty_buffer(num_subtasks * (n as usize + 1)); + let all_csc_val_idxs_buf = helper.create_empty_buffer(num_subtasks * input_size as usize); + let all_curr_buf = helper.create_empty_buffer(num_subtasks * n as usize); + let params_buf = helper.create_buffer(&vec![n, input_size]); // Execute shader let thread_group_count = helper.create_thread_group_size(num_subtasks as u64, 1, 1); @@ -603,7 +595,6 @@ fn test_transpose_large_scale() { &all_curr_buf, ¶ms_buf, ], - &[], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs index 17ac1123..50ede69f 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs @@ -38,9 +38,9 @@ pub fn test_ff_add() { assert!(a < p, "a must be less than p"); assert!(b < p, "b must be less than p"); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let b_buf = helper.create_input_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let b_buf = helper.create_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); // Calculate expected result: (a + b) % p let mut expected = a.clone(); @@ -70,8 +70,7 @@ pub fn test_ff_add() { helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs index 1a424760..469dae82 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_reduce.rs @@ -19,8 +19,8 @@ pub fn test_ff_reduce_a_less_than_p() { let (a, expected) = generate_test_values(false); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); let expected_limbs = expected.to_limbs(config.num_limbs, config.log_limb_size); @@ -29,8 +29,7 @@ pub fn test_ff_reduce_a_less_than_p() { helper.execute_shader( &config, - &[&a_buf], - &[&result_buf], + &[&a_buf, &result_buf], &thread_group_count, &thread_group_size, ); @@ -58,8 +57,8 @@ pub fn test_ff_reduce_a_greater_than_p_less_than_2p() { let (a, expected) = generate_test_values(true); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); let expected_limbs = expected.to_limbs(config.num_limbs, config.log_limb_size); @@ -68,8 +67,7 @@ pub fn test_ff_reduce_a_greater_than_p_less_than_2p() { helper.execute_shader( &config, - &[&a_buf], - &[&result_buf], + &[&a_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs index f35806c9..ce57ab7c 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs @@ -38,9 +38,9 @@ pub fn test_ff_sub() { assert!(a < p, "a must be less than p"); assert!(b < p, "b must be less than p"); - let a_buf = helper.create_input_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); - let b_buf = helper.create_input_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); - let result_buf = helper.create_output_buffer(config.num_limbs); + let a_buf = helper.create_buffer(&a.to_limbs(config.num_limbs, config.log_limb_size)); + let b_buf = helper.create_buffer(&b.to_limbs(config.num_limbs, config.log_limb_size)); + let result_buf = helper.create_empty_buffer(config.num_limbs); // (a - b) % p let mut expected = a.clone(); @@ -74,8 +74,7 @@ pub fn test_ff_sub() { helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/misc/get_constant.rs b/mopro-msm/src/msm/metal_msm/tests/misc/get_constant.rs index 993f8f9d..647b32e5 100644 --- a/mopro-msm/src/msm/metal_msm/tests/misc/get_constant.rs +++ b/mopro-msm/src/msm/metal_msm/tests/misc/get_constant.rs @@ -22,14 +22,13 @@ pub fn test_get_n0() { }; let mut helper = MetalHelper::new(); - let result_buf = helper.create_output_buffer(config.num_limbs); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&result_buf], &thread_group_count, &thread_group_size, @@ -58,14 +57,13 @@ pub fn test_get_p() { }; let mut helper = MetalHelper::new(); - let result_buf = helper.create_output_buffer(config.num_limbs); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&result_buf], &thread_group_count, &thread_group_size, @@ -93,14 +91,13 @@ pub fn test_get_r() { }; let mut helper = MetalHelper::new(); - let result_buf = helper.create_output_buffer(NUM_LIMBS_WIDE); + let result_buf = helper.create_empty_buffer(NUM_LIMBS_WIDE); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&result_buf], &thread_group_count, &thread_group_size, @@ -129,14 +126,13 @@ pub fn test_get_p_wide() { }; let mut helper = MetalHelper::new(); - let result_buf = helper.create_output_buffer(config.num_limbs); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&result_buf], &thread_group_count, &thread_group_size, @@ -164,16 +160,15 @@ pub fn test_get_bn254_zero() { }; let mut helper = MetalHelper::new(); - let x_buf = helper.create_output_buffer(config.num_limbs); - let y_buf = helper.create_output_buffer(config.num_limbs); - let z_buf = helper.create_output_buffer(config.num_limbs); + let x_buf = helper.create_empty_buffer(config.num_limbs); + let y_buf = helper.create_empty_buffer(config.num_limbs); + let z_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&x_buf, &y_buf, &z_buf], &thread_group_count, &thread_group_size, @@ -204,16 +199,15 @@ pub fn test_get_bn254_one() { }; let mut helper = MetalHelper::new(); - let x_buf = helper.create_output_buffer(config.num_limbs); - let y_buf = helper.create_output_buffer(config.num_limbs); - let z_buf = helper.create_output_buffer(config.num_limbs); + let x_buf = helper.create_empty_buffer(config.num_limbs); + let y_buf = helper.create_empty_buffer(config.num_limbs); + let z_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&x_buf, &y_buf, &z_buf], &thread_group_count, &thread_group_size, @@ -242,16 +236,15 @@ pub fn test_get_bn254_zero_mont() { }; let mut helper = MetalHelper::new(); - let x_buf = helper.create_output_buffer(config.num_limbs); - let y_buf = helper.create_output_buffer(config.num_limbs); - let z_buf = helper.create_output_buffer(config.num_limbs); + let x_buf = helper.create_empty_buffer(config.num_limbs); + let y_buf = helper.create_empty_buffer(config.num_limbs); + let z_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&x_buf, &y_buf, &z_buf], &thread_group_count, &thread_group_size, @@ -292,16 +285,15 @@ pub fn test_get_bn254_one_mont() { }; let mut helper = MetalHelper::new(); - let x_buf = helper.create_output_buffer(config.num_limbs); - let y_buf = helper.create_output_buffer(config.num_limbs); - let z_buf = helper.create_output_buffer(config.num_limbs); + let x_buf = helper.create_empty_buffer(config.num_limbs); + let y_buf = helper.create_empty_buffer(config.num_limbs); + let z_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&x_buf, &y_buf, &z_buf], &thread_group_count, &thread_group_size, @@ -343,14 +335,13 @@ pub fn test_get_mu() { let mut helper = MetalHelper::new(); - let result_buf = helper.create_output_buffer(config.num_limbs); + let result_buf = helper.create_empty_buffer(config.num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[], &[&result_buf], &thread_group_count, &thread_group_size, diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs index 04bcd039..f8af6b86 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs @@ -98,12 +98,10 @@ pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result { let mut helper = MetalHelper::new(); // Create buffers - let a_buf = - helper.create_input_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let b_buf = - helper.create_input_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let cost_buf = helper.create_input_buffer(&vec![cost as u32]); - let result_buf = helper.create_output_buffer(num_limbs); + let a_buf = helper.create_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let b_buf = helper.create_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let cost_buf = helper.create_buffer(&vec![cost as u32]); + let result_buf = helper.create_empty_buffer(num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); @@ -113,8 +111,7 @@ pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result { helper.execute_shader( &config, - &[&a_buf, &b_buf, &cost_buf], - &[&result_buf], + &[&a_buf, &b_buf, &cost_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs index eda259a1..4fcb0858 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs @@ -50,19 +50,16 @@ pub fn do_test(log_limb_size: u32) { .into_bigint() .to_limbs(num_limbs, log_limb_size); - let a_buf = - helper.create_input_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let b_buf = - helper.create_input_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let result_buf = helper.create_output_buffer(num_limbs); + let a_buf = helper.create_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let b_buf = helper.create_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let result_buf = helper.create_empty_buffer(num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs index 42f437cd..f7cec4a2 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs @@ -44,19 +44,16 @@ pub fn do_test(log_limb_size: u32) { .into_bigint() .to_limbs(num_limbs, log_limb_size); - let a_buf = - helper.create_input_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let b_buf = - helper.create_input_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let result_buf = helper.create_output_buffer(num_limbs); + let a_buf = helper.create_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let b_buf = helper.create_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let result_buf = helper.create_empty_buffer(num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs index 045d775d..3ab2bc64 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_optimised.rs @@ -44,19 +44,16 @@ pub fn do_test(log_limb_size: u32) { .into_bigint() .to_limbs(num_limbs, log_limb_size); - let a_buf = - helper.create_input_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let b_buf = - helper.create_input_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); - let result_buf = helper.create_output_buffer(num_limbs); + let a_buf = helper.create_buffer(&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let b_buf = helper.create_buffer(&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size)); + let result_buf = helper.create_empty_buffer(num_limbs); let thread_group_count = helper.create_thread_group_size(1, 1, 1); let thread_group_size = helper.create_thread_group_size(1, 1, 1); helper.execute_shader( &config, - &[&a_buf, &b_buf], - &[&result_buf], + &[&a_buf, &b_buf, &result_buf], &thread_group_count, &thread_group_size, ); diff --git a/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs b/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs index 7141a67c..9fbc72b4 100644 --- a/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs +++ b/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs @@ -71,15 +71,15 @@ impl MetalHelper { } } - /// Create an input buffer in Vec and track it - pub fn create_input_buffer(&mut self, data: &Vec) -> Buffer { + /// Create a buffer in Vec and track it + pub fn create_buffer(&mut self, data: &Vec) -> Buffer { let buffer = create_buffer(&self.device, data); self.buffers.push(buffer.clone()); buffer } - /// Create an output buffer and track it - pub fn create_output_buffer(&mut self, size: usize) -> Buffer { + /// Create an empty buffer and track it + pub fn create_empty_buffer(&mut self, size: usize) -> Buffer { let buffer = create_empty_buffer(&self.device, size); self.buffers.push(buffer.clone()); buffer @@ -98,8 +98,7 @@ impl MetalHelper { pub fn execute_shader( &self, config: &MetalConfig, - input_buffers: &[&Buffer], - output_buffers: &[&Buffer], + buffers: &[&Buffer], thread_group_count: &MTLSize, threads_per_threadgroup: &MTLSize, ) { @@ -144,15 +143,11 @@ impl MetalHelper { encoder.set_compute_pipeline_state(&pipeline_state); - // Set input buffers - for (i, buffer) in input_buffers.iter().enumerate() { + // Set buffers + for (i, buffer) in buffers.iter().enumerate() { encoder.set_buffer(i as u64, Some(buffer), 0); } - // Set output buffer - for (i, buffer) in output_buffers.iter().enumerate() { - encoder.set_buffer(input_buffers.len() as u64 + i as u64, Some(buffer), 0); - } encoder.dispatch_thread_groups(*thread_group_count, *threads_per_threadgroup); encoder.end_encoding(); From 9c4a161c3f5e3665199528f010490cbb88317414 Mon Sep 17 00:00:00 2001 From: moven0831 Date: Tue, 3 Jun 2025 13:02:53 +0800 Subject: [PATCH 2/2] refactor(metal_msm): impl ShaderManager for pre-compiled shader handling --- mopro-msm/src/msm/metal_msm/host/shader.rs | 33 ++ mopro-msm/src/msm/metal_msm/metal_msm.rs | 187 ++++----- .../src/msm/metal_msm/utils/metal_wrapper.rs | 102 ++++- mopro-msm/src/msm/metal_msm/utils/mod.rs | 1 + .../src/msm/metal_msm/utils/shader_manager.rs | 354 ++++++++++++++++++ 5 files changed, 580 insertions(+), 97 deletions(-) create mode 100644 mopro-msm/src/msm/metal_msm/utils/shader_manager.rs diff --git a/mopro-msm/src/msm/metal_msm/host/shader.rs b/mopro-msm/src/msm/metal_msm/host/shader.rs index d6878115..e63b1153 100644 --- a/mopro-msm/src/msm/metal_msm/host/shader.rs +++ b/mopro-msm/src/msm/metal_msm/host/shader.rs @@ -36,6 +36,39 @@ macro_rules! write_constant_array { }; } +/// Get the shader directory path using CARGO_MANIFEST_DIR for proper OS file location +/// This function provides robust path resolution that works across different build environments +pub fn get_shader_dir() -> PathBuf { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let shader_path = manifest_dir + .join("src") + .join("msm") + .join("metal_msm") + .join("shader"); + + // Verify the path exists, if not try alternative locations + if shader_path.exists() { + shader_path + } else { + // Fallback: try relative path from workspace root + let workspace_shader_path = manifest_dir + .parent() // go up one level from mopro-msm to workspace root + .unwrap_or(&manifest_dir) + .join("mopro-msm") + .join("src") + .join("msm") + .join("metal_msm") + .join("shader"); + + if workspace_shader_path.exists() { + workspace_shader_path + } else { + // Final fallback: use the original path and let error handling take care of it + shader_path + } + } +} + pub fn compile_metal(path_from_cargo_manifest_dir: &str, input_filename: &str) -> String { let input_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join(path_from_cargo_manifest_dir) diff --git a/mopro-msm/src/msm/metal_msm/metal_msm.rs b/mopro-msm/src/msm/metal_msm/metal_msm.rs index c6fc0503..ec59188e 100644 --- a/mopro-msm/src/msm/metal_msm/metal_msm.rs +++ b/mopro-msm/src/msm/metal_msm/metal_msm.rs @@ -3,6 +3,9 @@ use crate::msm::metal_msm::utils::limbs_conversion::{ }; use crate::msm::metal_msm::utils::metal_wrapper::{MetalConfig, MetalHelper}; use crate::msm::metal_msm::utils::mont_reduction::raw_reduction; +use crate::msm::metal_msm::utils::shader_manager::{ + ShaderManager, ShaderManagerConfig, ShaderType, +}; use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as Affine, G1Projective as G}; use ark_ff::{BigInt, PrimeField}; use ark_std::{vec::Vec, Zero}; @@ -25,17 +28,33 @@ impl Default for MetalMSMConfig { } } -/// Main Metal MSM pipeline -struct MetalMSMPipeline { +impl From for ShaderManagerConfig { + fn from(config: MetalMSMConfig) -> Self { + Self { + num_limbs: config.num_limbs, + log_limb_size: config.log_limb_size, + } + } +} + +/// Main Metal MSM pipeline with pre-compiled shaders +pub struct MetalMSMPipeline { config: MetalMSMConfig, + shader_manager: ShaderManager, } impl MetalMSMPipeline { - fn new(config: MetalMSMConfig) -> Self { - Self { config } + fn new(config: MetalMSMConfig) -> Result> { + let shader_config: ShaderManagerConfig = config.clone().into(); + let shader_manager = ShaderManager::new(shader_config)?; + + Ok(Self { + config, + shader_manager, + }) } - fn with_default_config() -> Self { + fn with_default_config() -> Result> { Self::new(MetalMSMConfig::default()) } @@ -55,9 +74,7 @@ impl MetalMSMPipeline { let (coords, scals) = pack_affine_and_scalars(bases, scalars, &metal_config); // Stage 1: Convert Point & Scalar Decomposition - // 1. Unpack point coordinates and encode them into Montgomery form - // 2. Decompose scalars into Signed wNAF form - let stage1 = ConvertPointAndScalarDecompose::new(&self.config); + let stage1 = ConvertPointAndScalarDecompose::new(&self.shader_manager); let mut c_workgroup_size = 256; let mut c_num_x_workgroups = 64; let mut c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups; @@ -98,7 +115,7 @@ impl MetalMSMPipeline { let t_num_z_workgroups = 1; let t_workgroup_size = num_subtasks; - let stage2 = Transpose::new(&self.config); + let stage2 = Transpose::new(&self.shader_manager); let (csc_col_ptr, csc_val_idxs) = stage2.execute( &scalar_chunks, num_subtasks, @@ -116,7 +133,7 @@ impl MetalMSMPipeline { let s_num_y_workgroups = 2; let s_num_z_workgroups = num_subtasks; - let stage3 = SMVP::new(&self.config); + let stage3 = SMVP::new(&self.shader_manager); let (bucket_x, bucket_y, bucket_z) = stage3.execute( &csc_col_ptr, &csc_val_idxs, @@ -132,7 +149,6 @@ impl MetalMSMPipeline { )?; // Stage 4: Parallel Bucket Reduction - // dynamic variable that determines the number of CSR matrices processed per invocation of the BPR shader. A safe default is 1. let num_subtasks_per_bpr_1 = 16; let num_subtasks_per_bpr_2 = 16; @@ -145,7 +161,7 @@ impl MetalMSMPipeline { let b_2_num_y_workgroups = 1; let b_2_num_z_workgroups = 1; - let stage4 = PBPR::new(&self.config); + let stage4 = PBPR::new(&self.shader_manager); let (g_points_x, g_points_y, g_points_z) = stage4.execute( &bucket_x, &bucket_y, @@ -237,20 +253,13 @@ impl MetalMSMPipeline { } /// Stage 1: Convert & Decompose -struct ConvertPointAndScalarDecompose { - config: MetalConfig, +struct ConvertPointAndScalarDecompose<'a> { + shader_manager: &'a ShaderManager, } -impl ConvertPointAndScalarDecompose { - fn new(msm_config: &MetalMSMConfig) -> Self { - Self { - config: MetalConfig { - log_limb_size: msm_config.log_limb_size, - num_limbs: msm_config.num_limbs, - shader_file: "cuzk/convert_point_coords_and_decompose_scalars.metal".to_string(), - kernel_name: "convert_point_coords_and_decompose_scalars".to_string(), - }, - } +impl<'a> ConvertPointAndScalarDecompose<'a> { + fn new(shader_manager: &'a ShaderManager) -> Self { + Self { shader_manager } } fn execute( @@ -264,13 +273,19 @@ impl ConvertPointAndScalarDecompose { c_num_z_workgroups: usize, c_workgroup_size: usize, ) -> Result<(Vec, Vec, Vec), Box> { - let mut helper = MetalHelper::new(); + let mut helper = MetalHelper::with_device(self.shader_manager.device().clone()); + let shader = self + .shader_manager + .get_shader(&ShaderType::ConvertPointAndDecompose) + .ok_or("ConvertPointAndDecompose shader not found")?; let coords_buf = helper.create_buffer(&coords.to_vec()); let scalars_buf = helper.create_buffer(&scalars.to_vec()); - let out_point_x = helper.create_empty_buffer(input_size * self.config.num_limbs); - let out_point_y = helper.create_empty_buffer(input_size * self.config.num_limbs); + let out_point_x = + helper.create_empty_buffer(input_size * self.shader_manager.config().num_limbs); + let out_point_y = + helper.create_empty_buffer(input_size * self.shader_manager.config().num_limbs); let out_scalar_chunks = helper.create_empty_buffer(input_size * num_subtasks); let input_size_buf = helper.create_buffer(&vec![input_size as u32]); @@ -284,8 +299,8 @@ impl ConvertPointAndScalarDecompose { let threads_per_threadgroup = helper.create_thread_group_size(c_workgroup_size as u64, 1, 1); - helper.execute_shader( - &self.config, + helper.execute_shader_with_pipeline( + &shader.pipeline_state, &[ &coords_buf, &scalars_buf, @@ -299,8 +314,14 @@ impl ConvertPointAndScalarDecompose { &threads_per_threadgroup, ); - let point_x = helper.read_results(&out_point_x, input_size * self.config.num_limbs); - let point_y = helper.read_results(&out_point_y, input_size * self.config.num_limbs); + let point_x = helper.read_results( + &out_point_x, + input_size * self.shader_manager.config().num_limbs, + ); + let point_y = helper.read_results( + &out_point_y, + input_size * self.shader_manager.config().num_limbs, + ); let scalar_chunks = helper.read_results(&out_scalar_chunks, input_size * num_subtasks); helper.drop_all_buffers(); @@ -310,20 +331,13 @@ impl ConvertPointAndScalarDecompose { } /// Stage 2: Transpose -struct Transpose { - config: MetalConfig, +struct Transpose<'a> { + shader_manager: &'a ShaderManager, } -impl Transpose { - fn new(msm_config: &MetalMSMConfig) -> Self { - Self { - config: MetalConfig { - log_limb_size: msm_config.log_limb_size, - num_limbs: msm_config.num_limbs, - shader_file: "cuzk/transpose.metal".to_string(), - kernel_name: "transpose".to_string(), - }, - } +impl<'a> Transpose<'a> { + fn new(shader_manager: &'a ShaderManager) -> Self { + Self { shader_manager } } fn execute( @@ -337,7 +351,11 @@ impl Transpose { t_num_z_workgroups: usize, t_workgroup_size: usize, ) -> Result<(Vec, Vec), Box> { - let mut helper = MetalHelper::new(); + let mut helper = MetalHelper::with_device(self.shader_manager.device().clone()); + let shader = self + .shader_manager + .get_shader(&ShaderType::Transpose) + .ok_or("Transpose shader not found")?; let in_chunks_buf = helper.create_buffer(&scalar_chunks.to_vec()); let out_csc_col_ptr = @@ -356,8 +374,8 @@ impl Transpose { let threads_per_threadgroup = helper.create_thread_group_size(t_workgroup_size as u64, 1, 1); - helper.execute_shader( - &self.config, + helper.execute_shader_with_pipeline( + &shader.pipeline_state, &[ &in_chunks_buf, &out_csc_col_ptr, @@ -382,20 +400,13 @@ impl Transpose { } /// Stage 3: SMVP -struct SMVP { - config: MetalConfig, +struct SMVP<'a> { + shader_manager: &'a ShaderManager, } -impl SMVP { - fn new(msm_config: &MetalMSMConfig) -> Self { - Self { - config: MetalConfig { - log_limb_size: msm_config.log_limb_size, - num_limbs: msm_config.num_limbs, - shader_file: "cuzk/smvp.metal".to_string(), - kernel_name: "smvp".to_string(), - }, - } +impl<'a> SMVP<'a> { + fn new(shader_manager: &'a ShaderManager) -> Self { + Self { shader_manager } } fn execute( @@ -412,9 +423,14 @@ impl SMVP { s_num_z_workgroups: usize, s_workgroup_size: usize, ) -> Result<(Vec, Vec, Vec), Box> { - let mut helper = MetalHelper::new(); + let mut helper = MetalHelper::with_device(self.shader_manager.device().clone()); + let shader = self + .shader_manager + .get_shader(&ShaderType::SMVP) + .ok_or("SMVP shader not found")?; - let bucket_size = (num_columns / 2) as usize * self.config.num_limbs * 4 * num_subtasks; + let bucket_size = + (num_columns / 2) as usize * self.shader_manager.config().num_limbs * 4 * num_subtasks; // Create buffers let row_ptr_buf = helper.create_buffer(&csc_col_ptr.to_vec()); @@ -448,8 +464,8 @@ impl SMVP { let threads_per_threadgroup = helper.create_thread_group_size(s_workgroup_size as u64, 1, 1); - helper.execute_shader( - &self.config, + helper.execute_shader_with_pipeline( + &shader.pipeline_state, &[ &row_ptr_buf, &val_idx_buf, @@ -476,27 +492,13 @@ impl SMVP { } /// Stage 4: PBPR -struct PBPR { - stage1_config: MetalConfig, - stage2_config: MetalConfig, +struct PBPR<'a> { + shader_manager: &'a ShaderManager, } -impl PBPR { - fn new(msm_config: &MetalMSMConfig) -> Self { - Self { - stage1_config: MetalConfig { - log_limb_size: msm_config.log_limb_size, - num_limbs: msm_config.num_limbs, - shader_file: "cuzk/pbpr.metal".to_string(), - kernel_name: "bpr_stage_1".to_string(), - }, - stage2_config: MetalConfig { - log_limb_size: msm_config.log_limb_size, - num_limbs: msm_config.num_limbs, - shader_file: "cuzk/pbpr.metal".to_string(), - kernel_name: "bpr_stage_2".to_string(), - }, - } +impl<'a> PBPR<'a> { + fn new(shader_manager: &'a ShaderManager) -> Self { + Self { shader_manager } } fn execute( @@ -516,13 +518,22 @@ impl PBPR { b_2_num_z_workgroups: usize, b_workgroup_size: usize, ) -> Result<(Vec, Vec, Vec), Box> { - let mut helper = MetalHelper::new(); + let mut helper = MetalHelper::with_device(self.shader_manager.device().clone()); + let stage1_shader = self + .shader_manager + .get_shader(&ShaderType::BPRStage1) + .ok_or("BPRStage1 shader not found")?; + let stage2_shader = self + .shader_manager + .get_shader(&ShaderType::BPRStage2) + .ok_or("BPRStage2 shader not found")?; let bucket_sum_x_buf = helper.create_buffer(&bucket_x.to_vec()); let bucket_sum_y_buf = helper.create_buffer(&bucket_y.to_vec()); let bucket_sum_z_buf = helper.create_buffer(&bucket_z.to_vec()); - let g_points_size = num_subtasks * b_workgroup_size * self.stage1_config.num_limbs * 4; + let g_points_size = + num_subtasks * b_workgroup_size * self.shader_manager.config().num_limbs * 4; let g_points_x_buf = helper.create_empty_buffer(g_points_size); let g_points_y_buf = helper.create_empty_buffer(g_points_size); let g_points_z_buf = helper.create_empty_buffer(g_points_size); @@ -545,8 +556,8 @@ impl PBPR { let stage1_threads_per_threadgroup = helper.create_thread_group_size(b_workgroup_size as u64, 1, 1); - helper.execute_shader( - &self.stage1_config, + helper.execute_shader_with_pipeline( + &stage1_shader.pipeline_state, &[ &bucket_sum_x_buf, &bucket_sum_y_buf, @@ -580,8 +591,8 @@ impl PBPR { let stage2_threads_per_threadgroup = helper.create_thread_group_size(b_workgroup_size as u64, 1, 1); - helper.execute_shader( - &self.stage2_config, + helper.execute_shader_with_pipeline( + &stage2_shader.pipeline_state, &[ &bucket_sum_x_buf, &bucket_sum_y_buf, @@ -623,7 +634,7 @@ pub fn metal_variable_base_msm( return Err("Bases and scalars must have the same length".into()); } - let pipeline = MetalMSMPipeline::with_default_config(); + let pipeline = MetalMSMPipeline::with_default_config()?; pipeline.execute(bases, scalars) } diff --git a/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs b/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs index 9fbc72b4..73b43379 100644 --- a/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs +++ b/mopro-msm/src/msm/metal_msm/utils/metal_wrapper.rs @@ -1,7 +1,7 @@ use crate::msm::metal_msm::host::gpu::{ create_buffer, create_empty_buffer, get_default_device, read_buffer, }; -use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use crate::msm::metal_msm::host::shader::{compile_metal, get_shader_dir, write_constants}; use crate::msm::metal_msm::utils::barrett_params::calc_barrett_mu; use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; @@ -14,14 +14,12 @@ use once_cell::sync::Lazy; use std::collections::HashMap; use std::sync::Mutex; -/// Directory containing the Metal shader files -const SHADER_DIR: &str = "../mopro-msm/src/msm/metal_msm/shader"; - /// Cache of precomputed constants static CONSTANTS_CACHE: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); /// Struct for Metal config +#[derive(Clone)] pub struct MetalConfig { pub log_limb_size: u32, pub num_limbs: usize, @@ -51,7 +49,7 @@ impl Default for MetalConfig { } } -/// Helper to setup Metal device, buffers, and execute shader +/// Enhanced Metal helper that works with pre-compiled shaders pub struct MetalHelper { pub device: Device, pub command_queue: CommandQueue, @@ -59,7 +57,7 @@ pub struct MetalHelper { } impl MetalHelper { - /// Create a new Metal helper + /// Create a new Metal helper with specific device pub fn new() -> Self { let device = get_default_device(); let command_queue = device.new_command_queue(); @@ -71,6 +69,17 @@ impl MetalHelper { } } + /// Create a new Metal helper with custom device + pub fn with_device(device: Device) -> Self { + let command_queue = device.new_command_queue(); + + Self { + device, + command_queue, + buffers: Vec::new(), + } + } + /// Create a buffer in Vec and track it pub fn create_buffer(&mut self, data: &Vec) -> Buffer { let buffer = create_buffer(&self.device, data); @@ -94,7 +103,34 @@ impl MetalHelper { } } - /// Execute a Metal compute shader + /// Execute a Metal compute shader with pre-compiled pipeline state + pub fn execute_shader_with_pipeline( + &self, + pipeline_state: &ComputePipelineState, + buffers: &[&Buffer], + thread_group_count: &MTLSize, + threads_per_threadgroup: &MTLSize, + ) { + let command_buffer = self.command_queue.new_command_buffer(); + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = + command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + encoder.set_compute_pipeline_state(&pipeline_state); + + // Set buffers + for (i, buffer) in buffers.iter().enumerate() { + encoder.set_buffer(i as u64, Some(buffer), 0); + } + + encoder.dispatch_thread_groups(*thread_group_count, *threads_per_threadgroup); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + } + + /// Execute a Metal compute shader (legacy method - kept for compatibility) pub fn execute_shader( &self, config: &MetalConfig, @@ -110,7 +146,7 @@ impl MetalHelper { // Setup shader constants let constants = get_or_calc_constants(config.num_limbs, config.log_limb_size); write_constants( - SHADER_DIR, + get_shader_dir().to_str().unwrap(), config.num_limbs, config.log_limb_size, constants.n0, @@ -118,7 +154,11 @@ impl MetalHelper { ); // Prepare full shader path - let shader_path = format!("{}/{}", SHADER_DIR, config.shader_file); + let shader_path = format!( + "{}/{}", + get_shader_dir().to_str().unwrap(), + config.shader_file + ); let parts: Vec<&str> = shader_path.rsplitn(2, '/').collect(); let shader_dir = if parts.len() > 1 { parts[1] } else { "" }; let shader_file = parts[0]; @@ -164,6 +204,50 @@ impl MetalHelper { pub fn drop_all_buffers(&mut self) { self.buffers.clear(); } + + /// Get device reference + pub fn device(&self) -> &Device { + &self.device + } + + /// Get command queue reference + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + /// Execute multiple compute shaders in sequence with shared command buffer + pub fn execute_shaders_batch(&self, operations: Vec) { + let command_buffer = self.command_queue.new_command_buffer(); + + for operation in operations { + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = + command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + encoder.set_compute_pipeline_state(&operation.pipeline_state); + + for (i, buffer) in operation.buffers.iter().enumerate() { + encoder.set_buffer(i as u64, Some(buffer), 0); + } + + encoder.dispatch_thread_groups( + operation.thread_group_count, + operation.threads_per_threadgroup, + ); + encoder.end_encoding(); + } + + command_buffer.commit(); + command_buffer.wait_until_completed(); + } +} + +/// Represents a single shader operation for batch execution +pub struct ShaderOperation<'a> { + pub pipeline_state: ComputePipelineState, + pub buffers: Vec<&'a Buffer>, + pub thread_group_count: MTLSize, + pub threads_per_threadgroup: MTLSize, } // Calculate or retrieve cached constants diff --git a/mopro-msm/src/msm/metal_msm/utils/mod.rs b/mopro-msm/src/msm/metal_msm/utils/mod.rs index 2ff47402..7b7c1866 100644 --- a/mopro-msm/src/msm/metal_msm/utils/mod.rs +++ b/mopro-msm/src/msm/metal_msm/utils/mod.rs @@ -3,3 +3,4 @@ pub mod limbs_conversion; pub mod metal_wrapper; pub mod mont_params; pub mod mont_reduction; +pub mod shader_manager; diff --git a/mopro-msm/src/msm/metal_msm/utils/shader_manager.rs b/mopro-msm/src/msm/metal_msm/utils/shader_manager.rs new file mode 100644 index 00000000..10c91591 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/utils/shader_manager.rs @@ -0,0 +1,354 @@ +use crate::msm::metal_msm::host::gpu::get_default_device; +use crate::msm::metal_msm::host::shader::{compile_metal, get_shader_dir, write_constants}; +use crate::msm::metal_msm::utils::metal_wrapper::{ + get_or_calc_constants, MSMConstants, MetalConfig, +}; +use metal::*; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Mutex; + +/// Cache of compiled pipeline states +static PIPELINE_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +/// Shader types used in MSM pipeline +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ShaderType { + ConvertPointAndDecompose, + Transpose, + SMVP, + BPRStage1, + BPRStage2, +} + +impl ShaderType { + pub fn get_config(&self, num_limbs: usize, log_limb_size: u32) -> MetalConfig { + match self { + ShaderType::ConvertPointAndDecompose => MetalConfig { + log_limb_size, + num_limbs, + shader_file: "cuzk/convert_point_coords_and_decompose_scalars.metal".to_string(), + kernel_name: "convert_point_coords_and_decompose_scalars".to_string(), + }, + ShaderType::Transpose => MetalConfig { + log_limb_size, + num_limbs, + shader_file: "cuzk/transpose.metal".to_string(), + kernel_name: "transpose".to_string(), + }, + ShaderType::SMVP => MetalConfig { + log_limb_size, + num_limbs, + shader_file: "cuzk/smvp.metal".to_string(), + kernel_name: "smvp".to_string(), + }, + ShaderType::BPRStage1 => MetalConfig { + log_limb_size, + num_limbs, + shader_file: "cuzk/pbpr.metal".to_string(), + kernel_name: "bpr_stage_1".to_string(), + }, + ShaderType::BPRStage2 => MetalConfig { + log_limb_size, + num_limbs, + shader_file: "cuzk/pbpr.metal".to_string(), + kernel_name: "bpr_stage_2".to_string(), + }, + } + } +} + +/// Configuration for shader manager +#[derive(Clone, Debug)] +pub struct ShaderManagerConfig { + pub num_limbs: usize, + pub log_limb_size: u32, +} + +impl Default for ShaderManagerConfig { + fn default() -> Self { + Self { + num_limbs: 16, + log_limb_size: 16, + } + } +} + +/// Pre-compiled shader information +#[derive(Clone)] +pub struct PrecompiledShader { + pub pipeline_state: ComputePipelineState, + pub config: MetalConfig, + pub constants: MSMConstants, +} + +/// Main shader manager that handles pre-compilation and caching +#[derive(Clone)] +pub struct ShaderManager { + device: Device, + config: ShaderManagerConfig, + shaders: HashMap, + constants: MSMConstants, +} + +impl ShaderManager { + /// Create a new shader manager with the given configuration + pub fn new(config: ShaderManagerConfig) -> Result> { + let device = get_default_device(); + let constants = get_or_calc_constants(config.num_limbs, config.log_limb_size); + + // Pre-write constants to avoid doing it on every shader execution + write_constants( + get_shader_dir().to_str().unwrap(), + config.num_limbs, + config.log_limb_size, + constants.n0, + constants.nsafe, + ); + + let mut manager = Self { + device, + config: config.clone(), + shaders: HashMap::new(), + constants, + }; + + // Pre-compile all shaders + manager.compile_all_shaders()?; + + Ok(manager) + } + + /// Create a shader manager with default configuration + pub fn with_default_config() -> Result> { + Self::new(ShaderManagerConfig::default()) + } + + /// Get the shader directory path (useful for debugging and external tools) + pub fn get_shader_directory() -> PathBuf { + get_shader_dir() + } + + /// Pre-compile all shaders used in the MSM pipeline + fn compile_all_shaders(&mut self) -> Result<(), Box> { + let shader_types = vec![ + ShaderType::ConvertPointAndDecompose, + ShaderType::Transpose, + ShaderType::SMVP, + ShaderType::BPRStage1, + ShaderType::BPRStage2, + ]; + + for shader_type in shader_types { + let precompiled = self.compile_shader(&shader_type)?; + self.shaders.insert(shader_type, precompiled); + } + + Ok(()) + } + + /// Compile a single shader and return precompiled information + fn compile_shader( + &self, + shader_type: &ShaderType, + ) -> Result> { + let config = shader_type.get_config(self.config.num_limbs, self.config.log_limb_size); + let cache_key = format!( + "{}_{}_{}_{}", + config.shader_file, + config.kernel_name, + self.config.num_limbs, + self.config.log_limb_size + ); + + // Check if already cached + { + let cache = PIPELINE_CACHE.lock().unwrap(); + if let Some(pipeline_state) = cache.get(&cache_key) { + return Ok(PrecompiledShader { + pipeline_state: pipeline_state.clone(), + config, + constants: self.constants.clone(), + }); + } + } + + // Compile shader + let shader_path = format!( + "{}/{}", + get_shader_dir().to_str().unwrap(), + config.shader_file + ); + let parts: Vec<&str> = shader_path.rsplitn(2, '/').collect(); + let shader_dir = if parts.len() > 1 { parts[1] } else { "" }; + let shader_file = parts[0]; + + let library_path = compile_metal(shader_dir, shader_file); + let library = self + .device + .new_library_with_file(library_path) + .map_err(|e| format!("Failed to create library: {:?}", e))?; + + let kernel = library + .get_function(config.kernel_name.as_str(), None) + .map_err(|e| { + format!( + "Failed to get kernel function {}: {:?}", + config.kernel_name, e + ) + })?; + + let pipeline_state = self + .device + .new_compute_pipeline_state_with_function(&kernel) + .map_err(|e| format!("Failed to create pipeline state: {:?}", e))?; + + // Cache the pipeline state + { + let mut cache = PIPELINE_CACHE.lock().unwrap(); + cache.insert(cache_key, pipeline_state.clone()); + } + + Ok(PrecompiledShader { + pipeline_state, + config, + constants: self.constants.clone(), + }) + } + + /// Get a precompiled shader by type + pub fn get_shader(&self, shader_type: &ShaderType) -> Option<&PrecompiledShader> { + self.shaders.get(shader_type) + } + + /// Get the Metal device + pub fn device(&self) -> &Device { + &self.device + } + + /// Get the configuration + pub fn config(&self) -> &ShaderManagerConfig { + &self.config + } + + /// Get the constants + pub fn constants(&self) -> &MSMConstants { + &self.constants + } + + /// Update configuration and recompile shaders if needed + pub fn update_config( + &mut self, + new_config: ShaderManagerConfig, + ) -> Result<(), Box> { + if self.config.num_limbs != new_config.num_limbs + || self.config.log_limb_size != new_config.log_limb_size + { + self.config = new_config; + self.constants = + get_or_calc_constants(self.config.num_limbs, self.config.log_limb_size); + + // Re-write constants + write_constants( + get_shader_dir().to_str().unwrap(), + self.config.num_limbs, + self.config.log_limb_size, + self.constants.n0, + self.constants.nsafe, + ); + + // Re-compile all shaders + self.shaders.clear(); + self.compile_all_shaders()?; + } + Ok(()) + } + + /// Clear the pipeline cache (useful for development/testing) + pub fn clear_cache() { + let mut cache = PIPELINE_CACHE.lock().unwrap(); + cache.clear(); + } +} + +/// Builder pattern for creating shader manager with custom configuration +pub struct ShaderManagerBuilder { + num_limbs: Option, + log_limb_size: Option, +} + +impl Default for ShaderManagerBuilder { + fn default() -> Self { + Self { + num_limbs: None, + log_limb_size: None, + } + } +} + +impl ShaderManagerBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn num_limbs(mut self, num_limbs: usize) -> Self { + self.num_limbs = Some(num_limbs); + self + } + + pub fn log_limb_size(mut self, log_limb_size: u32) -> Self { + self.log_limb_size = Some(log_limb_size); + self + } + + pub fn build(self) -> Result> { + let config = ShaderManagerConfig { + num_limbs: self.num_limbs.unwrap_or(16), + log_limb_size: self.log_limb_size.unwrap_or(16), + }; + ShaderManager::new(config) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[serial_test::serial] + fn test_shader_manager_creation() { + let manager = ShaderManager::with_default_config().unwrap(); + assert_eq!(manager.config().num_limbs, 16); + assert_eq!(manager.config().log_limb_size, 16); + } + + #[test] + #[serial_test::serial] + fn test_shader_manager_builder() { + let manager = ShaderManagerBuilder::new() + .num_limbs(16) + .log_limb_size(16) + .build() + .unwrap(); + + assert_eq!(manager.config().num_limbs, 16); + assert_eq!(manager.config().log_limb_size, 16); + } + + #[test] + #[serial_test::serial] + fn test_get_precompiled_shaders() { + let manager = ShaderManager::with_default_config().unwrap(); + + let convert_shader = manager.get_shader(&ShaderType::ConvertPointAndDecompose); + assert!(convert_shader.is_some()); + + let transpose_shader = manager.get_shader(&ShaderType::Transpose); + assert!(transpose_shader.is_some()); + + let smvp_shader = manager.get_shader(&ShaderType::SMVP); + assert!(smvp_shader.is_some()); + } +}