Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 99 additions & 98 deletions mopro-msm/build.rs
Original file line number Diff line number Diff line change
@@ -1,116 +1,117 @@
use std::{env, path::Path, process::Command};

fn main() {
compile_shaders();
use std::env;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::Command;

/// Build-script entry point: compiles the Metal shaders and reports any
/// I/O failure to cargo through the returned `Result`.
fn main() -> std::io::Result<()> {
    // `compile_shaders` already yields `io::Result<()>`, so hand it back as-is.
    compile_shaders()
}

fn compile_shaders() {
let shader_dir = "src/msm/metal/shader/";
let out_dir = env::var("OUT_DIR").unwrap();

// List your Metal shaders here.
let shaders = vec!["all.metal"];

let shaders_to_check = vec!["all.metal", "msm.h.metal"];

let mut air_files = vec![];

// Step 1: Compile every shader to AIR format
for shader in &shaders {
let shader_path = Path::new(shader_dir).join(shader);
let air_output = Path::new(&out_dir).join(format!("{}.air", shader));

let mut args = vec![
fn compile_shaders() -> std::io::Result<()> {
let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let shader_root = manifest_dir
.join("src")
.join("msm")
.join("metal_msm")
.join("shader")
.join("cuzk");
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let shader_out_dir = out_dir.join("shaders");
fs::create_dir_all(&shader_out_dir)?;

// Gather all .metal under /shader/cuzk
let mut metal_paths = Vec::new();
visit_dirs(&shader_root, &mut metal_paths)?;
// Filter only the desired kernels
metal_paths.retain(|path| {
path.file_name()
.and_then(|n| n.to_str())
.map(|name| {
name.starts_with("convert_point")
|| name.starts_with("barrett_reduction")
|| name.starts_with("pbpr")
|| name.starts_with("smvp")
|| name.starts_with("transpose")
})
.unwrap_or(false)
});

// Combine selected kernels into one .metal file
let combined = shader_out_dir.join("msm_combined.metal");
let mut combined_src = String::new();
combined_src.push_str("#include <metal_stdlib>\n#include <metal_math>\n");
for path in &metal_paths {
let inc = path.to_str().unwrap();
combined_src.push_str(&format!("#include \"{}\"\n", inc));
println!("cargo:rerun-if-changed={}", inc);
}
fs::write(&combined, &combined_src)?;

// Compile combined source to AIR
let target = env::var("TARGET").unwrap_or_default();
let sdk = if target.contains("apple-ios") {
"iphoneos"
} else {
"macosx"
};
let air = shader_out_dir.join("msm.air");
let status = Command::new("xcrun")
.args(&[
"-sdk",
get_sdk(),
sdk,
"metal",
"-c",
shader_path.to_str().unwrap(),
combined.to_str().unwrap(),
"-o",
air_output.to_str().unwrap(),
];

if cfg!(feature = "profiling-release") {
args.push("-frecord-sources");
}

// Compile shader into .air files
let status = Command::new("xcrun")
.args(&args)
.status()
.expect("Shader compilation failed");

if !status.success() {
panic!("Shader compilation failed for {}", shader);
}

air_files.push(air_output);
}

// Step 2: Link all the .air files into a Metallib archive
let metallib_output = Path::new(&out_dir).join("msm.metallib");

let mut metallib_args = vec![
"-sdk",
get_sdk(),
"metal",
"-o",
metallib_output.to_str().unwrap(),
];

if cfg!(feature = "profiling-release") {
metallib_args.push("-frecord-sources");
}

for air_file in &air_files {
metallib_args.push(air_file.to_str().unwrap());
}

let status = Command::new("xcrun")
.args(&metallib_args)
air.to_str().unwrap(),
])
.status()
.expect("Failed to link shaders into metallib");

.expect("Failed to invoke metal");
if !status.success() {
panic!("Failed to link shaders into metallib");
panic!("Metal compile failed");
}
// We now have single .air; proceed to link
let air_path = air.to_str().unwrap();

let symbols_args = vec![
"metal-dsymutil",
"-flat",
"-remove-source",
metallib_output.to_str().unwrap(),
];

// Link AIR into msm.metallib
let msm_lib = shader_out_dir.join("msm.metallib");
let status = Command::new("xcrun")
.args(&symbols_args)
.args(&[
"-sdk",
sdk,
"metallib",
air_path,
"-o",
msm_lib.to_str().unwrap(),
])
.status()
.expect("Failed to extract symbols");

.expect("Failed to invoke metallib");
if !status.success() {
panic!("Failed to extract symbols");
}

// Inform cargo to watch all shader files for changes
for shader in &shaders_to_check {
let shader_path = Path::new(shader_dir).join(shader);
println!("cargo:rerun-if-changed={}", shader_path.to_str().unwrap());
panic!("Metallib linking failed");
}
}

#[cfg(feature = "macos")]
fn get_sdk() -> &'static str {
"macosx"
// Emit a single built_shaders.rs with embedded msm.metallib
let dest = out_dir.join("built_shaders.rs");
let mut f = fs::File::create(&dest)?;
writeln!(
f,
"pub const MSM_METALLIB: &[u8] = include_bytes!(concat!(env!(\"OUT_DIR\"), \"/shaders/msm.metallib\"));"
)?;
Ok(())
}

/// Returns the SDK name handed to `xcrun -sdk` when the `ios` feature is
/// enabled and the `macos` feature is not.
#[cfg(not(feature = "macos"))]
#[cfg(feature = "ios")]
fn get_sdk() -> &'static str {
"iphoneos"
}

#[cfg(not(feature = "macos"))]
#[cfg(not(feature = "ios"))]
fn get_sdk() -> &'static str {
panic!("one of the features macos or ios needs to be enabled");
/// Recursively collects every file with a `.metal` extension under `dir`
/// and appends it to `paths`.
///
/// Entries of each directory are visited in sorted path order. The order
/// returned by `fs::read_dir` is platform and filesystem dependent, and the
/// caller turns the collected list into `#include` lines, so an unsorted
/// walk would make the generated source differ from machine to machine.
///
/// If `dir` is not a directory (including when it does not exist) nothing is
/// collected and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the underlying `io::Error` if a directory cannot be read or one
/// of its entries cannot be inspected.
fn visit_dirs(dir: &Path, paths: &mut Vec<PathBuf>) -> std::io::Result<()> {
    if !dir.is_dir() {
        return Ok(());
    }

    // Materialise the listing first so it can be sorted; the first failing
    // entry aborts the walk.
    let mut entries = fs::read_dir(dir)?
        .map(|entry| entry.map(|e| e.path()))
        .collect::<std::io::Result<Vec<PathBuf>>>()?;
    entries.sort();

    for p in entries {
        if p.is_dir() {
            visit_dirs(&p, paths)?;
        } else if p.extension().and_then(|e| e.to_str()) == Some("metal") {
            paths.push(p);
        }
    }
    Ok(())
}
114 changes: 0 additions & 114 deletions mopro-msm/src/msm/metal_msm/host/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ use ark_ff::{BigInt, PrimeField, Zero};
use num_bigint::BigUint;
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use std::string::String;

macro_rules! write_constant_array {
($data:expr, $name:expr, $values:expr, $size:expr) => {
Expand All @@ -36,108 +34,6 @@ macro_rules! write_constant_array {
};
}

/// Locates the Metal shader directory for this crate.
///
/// The primary candidate is `<CARGO_MANIFEST_DIR>/src/msm/metal_msm/shader`.
/// When that path does not exist, the same sub-path is tried under
/// `<parent of CARGO_MANIFEST_DIR>/mopro-msm`. If neither exists, the primary
/// candidate is returned anyway so that the caller's own error handling
/// reports the missing directory.
pub fn get_shader_dir() -> PathBuf {
    // Sub-path of the shader directory relative to the crate root.
    const SHADER_SUBPATH: [&str; 4] = ["src", "msm", "metal_msm", "shader"];

    let crate_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let primary = SHADER_SUBPATH
        .iter()
        .fold(crate_root.clone(), |acc, segment| acc.join(segment));
    if primary.exists() {
        return primary;
    }

    // One level up from the manifest directory; falls back to the manifest
    // directory itself when it has no parent.
    let workspace_root = crate_root.parent().unwrap_or(&crate_root);
    let fallback = SHADER_SUBPATH
        .iter()
        .fold(workspace_root.join("mopro-msm"), |acc, segment| {
            acc.join(segment)
        });

    if fallback.exists() {
        fallback
    } else {
        primary
    }
}

/// Compiles a single Metal source file with `xcrun metal` and returns the
/// path of the produced library.
///
/// The input is `<CARGO_MANIFEST_DIR>/<path_from_cargo_manifest_dir>/<input_filename>`
/// and the output is written next to it as `<input>.lib`.
///
/// On iOS targets the shader is built with `-std=metal3.2` and logging
/// enabled. On macOS those flags are only added when `sw_vers` reports a
/// major version of 15 or higher; if the version cannot be determined it is
/// treated as 0 and the flags are omitted.
///
/// # Panics
///
/// * if the input path is not valid UTF-8,
/// * if `xcrun` cannot be spawned,
/// * if the compiler writes anything to stderr (warnings included),
/// * if the compiler exits with a non-zero status,
/// * on targets other than iOS and macOS.
pub fn compile_metal(path_from_cargo_manifest_dir: &str, input_filename: &str) -> String {
    let input_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join(path_from_cargo_manifest_dir)
        .join(input_filename);
    // Convert once; the output path is derived from the same string.
    let c = input_path.into_os_string().into_string().unwrap();
    let lib = format!("{}.lib", c);

    let exe = if cfg!(target_os = "ios") {
        Command::new("xcrun")
            .args([
                "-sdk",
                "iphoneos",
                "metal",
                "-std=metal3.2",
                "-target",
                "air64-apple-ios18.0",
                "-fmetal-enable-logging",
                "-o",
                lib.as_str(),
                c.as_str(),
            ])
            .output()
            .expect("failed to compile")
    } else if cfg!(target_os = "macos") {
        // Major macOS version, or 0 when `sw_vers` is unavailable or its
        // output cannot be parsed.
        let macos_version = Command::new("sw_vers")
            .args(["-productVersion"])
            .output()
            .ok()
            .and_then(|output| String::from_utf8(output.stdout).ok())
            .and_then(|version| {
                version
                    .trim()
                    .split('.')
                    .next()
                    .and_then(|major| major.parse::<u32>().ok())
            })
            .unwrap_or(0);

        let mut args = vec!["-sdk", "macosx", "metal"];

        // Only specify Metal 3.2 for metal logging if macOS version is 15.0 or higher
        if macos_version >= 15 {
            args.extend([
                "-std=metal3.2",
                "-target",
                "air64-apple-macos15.0",
                "-fmetal-enable-logging",
            ]);
        }

        args.extend(["-o", lib.as_str(), c.as_str()]);

        Command::new("xcrun")
            .args(args)
            .output()
            .expect("failed to compile")
    } else {
        panic!("Unsupported architecture");
    };

    // Any diagnostic output is treated as a failure. Decode lossily so that
    // stray non-UTF-8 bytes cannot mask the real compiler message.
    if !exe.stderr.is_empty() {
        panic!("{}", String::from_utf8_lossy(&exe.stderr));
    }
    // A compiler that fails silently must not be reported as success:
    // without this check the caller receives a path to a library that was
    // never written.
    if !exe.status.success() {
        panic!("metal compiler exited with {} for {}", exe.status, c);
    }

    lib
}

pub fn write_constants(
filepath: &str,
num_limbs: usize,
Expand Down Expand Up @@ -293,16 +189,6 @@ pub fn write_constants(
pub mod tests {
use super::*;

#[test]
#[serial_test::serial]
pub fn test_compile() {
    // Smoke test: compile one shader and print where the library was written.
    println!(
        "{}",
        compile_metal(
            "../mopro-msm/src/msm/metal_msm/shader",
            "bigint/bigint_add_unsafe.metal",
        )
    );
}

#[test]
#[serial_test::serial]
pub fn test_write_constants() {
Expand Down
15 changes: 12 additions & 3 deletions mopro-msm/src/msm/metal_msm/shader/cuzk/pbpr.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using namespace metal;

#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)
#include <metal_logging>
constant os_log logger_kernel(/*subsystem=*/"pbpr", /*category=*/"metal");
#define LOG_DEBUG(...) logger_kernel.log_debug(__VA_ARGS__)
constant os_log pbpr_logger_kernel(/*subsystem=*/"pbpr", /*category=*/"metal");
#define LOG_DEBUG(...) pbpr_logger_kernel.log_debug(__VA_ARGS__)
#else
#define LOG_DEBUG(...) ((void)0)
#endif
Expand Down Expand Up @@ -45,9 +45,10 @@ kernel void bpr_stage_1(

const uint subtask_idx = params[0];
const uint num_columns = params[1];
const uint num_subtasks_per_bpr = params[2]; // Number of subtasks per shader invocation (must be power of 2).
const uint num_subtasks_per_bpr = params[2]; // Number of subtasks per shader invocation

const uint num_buckets_per_subtask = num_columns / 2u;
const uint total_buckets = num_buckets_per_subtask * num_subtasks_per_bpr;

// Number of buckets to reduce per thread.
const uint buckets_per_thread = num_buckets_per_subtask / num_threads_per_subtask;
Expand All @@ -59,6 +60,8 @@ kernel void bpr_stage_1(
if (thread_id % num_threads_per_subtask != 0u) {
idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * buckets_per_thread + offset;
}
// guard bucket bounds
if (idx >= total_buckets) { return; }

Jacobian m = {
.x = bucket_sum_x[idx],
Expand Down Expand Up @@ -86,6 +89,12 @@ kernel void bpr_stage_1(
bucket_sum_z[idx] = m.z;

uint g_rw_idx = (subtask_idx / num_subtasks_per_bpr) * (num_threads_per_subtask * num_subtasks_per_bpr) + thread_id;
// guard write into g_points buffers
if (g_rw_idx < num_threads_per_subtask * num_subtasks_per_bpr) {
g_points_x[g_rw_idx] = g.x;
g_points_y[g_rw_idx] = g.y;
g_points_z[g_rw_idx] = g.z;
}
g_points_x[g_rw_idx] = g.x;
g_points_y[g_rw_idx] = g.y;
g_points_z[g_rw_idx] = g.z;
Expand Down
Loading