Skip to content

Generate Rusty wrapper by procedural macro, merge lapack crate #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
/Cargo.lock
/target
Cargo.lock
target/
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "lapack"]
path = lapack
path = lapack-sys/lapack
url = https://github.com/Reference-LAPACK/lapack
21 changes: 4 additions & 17 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
[package]
name = "lapack-sys"
version = "0.12.1"
license = "Apache-2.0/MIT"
authors = [
"Andrew Straw <[email protected]>",
"Corey Richardson <[email protected]>",
"Ivan Ukhov <[email protected]>",
[workspace]
members = [
"lapack-sys",
"lapack-derive",
]
description = "The package provides bindings to LAPACK (Fortran)."
documentation = "https://docs.rs/lapack-sys"
homepage = "https://github.com/blas-lapack-rs/lapack-sys"
repository = "https://github.com/blas-lapack-rs/lapack-sys"
readme = "README.md"
categories = ["external-ffi-bindings", "science"]
keywords = ["linear-algebra"]

[dependencies]
libc = "0.2"
13 changes: 13 additions & 0 deletions lapack-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "lapack-derive"
version = "0.1.0"
authors = ["Toshiki Teramura <[email protected]>"]
edition = "2018"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0.18"
quote = "1.0.6"
syn = { version = "1.0.30", features = ["full", "extra-traits"] }
338 changes: 338 additions & 0 deletions lapack-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree};
use quote::quote;

type Args = syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>;
type Call = syn::punctuated::Punctuated<syn::Expr, syn::token::Comma>;

#[proc_macro_attribute]
pub fn lapack(_attr: TokenStream, func: TokenStream) -> TokenStream {
lapack2(syn::parse(func).unwrap()).into()
}

/// TokenStream2-based main routine
fn lapack2(func: TokenStream2) -> TokenStream2 {
let f = parse_foreign_fn(&func);
let wrap = wrap(&f);
quote! {
#func
#wrap
}
}

/// extern "C" { fn dgetrs_(...); } -> fn dgetrs_(...);
fn parse_foreign_fn(func: &TokenStream2) -> syn::ForeignItemFn {
let func = if let Some(func) = func.clone().into_iter().skip(2 /* 'extern', 'C' */).next() {
if let TokenTree::Group(group) = func {
group.stream()
} else {
unreachable!("#[lapack] attribute must be put to `extern \"C\"` block")
}
} else {
unreachable!("#[lapack] attribute must be put to `extern \"C\"` block")
};
syn::parse2(func).unwrap()
}

/// Generate token stream of wrapped function
fn wrap(f: &syn::ForeignItemFn) -> TokenStream2 {
// like dgetrs_
let lapack_sys_name = &f.sig.ident;
// like dgetrs
let lapack_name = lapack_sys_name
.to_string()
.trim_end_matches('_')
.to_string();
let lapack_name = syn::Ident::new(&lapack_name, Span::call_site());
let input = signature_input(&f.sig.inputs);
let call = call(&f.sig.inputs);
let output = &f.sig.output;
quote! {
pub unsafe fn #lapack_name ( #input ) #output {
#lapack_sys_name ( #call )
}
}
}

enum ArgType {
/// `T`
Value(String),
/// `*const T`
ConstPtr(String),
/// `*mut T`
MutPtr(String),
}

impl From<syn::TypePtr> for ArgType {
fn from(ptr_ty: syn::TypePtr) -> Self {
match &*ptr_ty.elem {
syn::Type::Path(path) => {
let path = quote! { #path }.to_string();
match ptr_ty.mutability {
Some(_) => ArgType::MutPtr(path),
None => ArgType::ConstPtr(path),
}
}
_ => unimplemented!("Pointer for non-path is not supported yet"),
}
}
}

impl From<syn::TypePath> for ArgType {
fn from(path: syn::TypePath) -> Self {
ArgType::Value(quote! { #path }.to_string())
}
}

/// Parse type ascription pattern `a: *mut f64` into ("a", "f64")
fn parse_input(pat: &syn::PatType) -> (String, ArgType) {
let name = match &*pat.pat {
syn::Pat::Ident(ident) => ident.ident.to_string(),
_ => unreachable!(),
};
let arg_type = match &*pat.ty {
syn::Type::Ptr(ptr_ty) => ptr_ty.clone().into(),
syn::Type::Path(path) => path.clone().into(),
_ => unimplemented!("Only Path and Pointer are supported yet"),
};
(name, arg_type)
}

fn is_value(name: &str) -> bool {
match name.to_lowercase().as_str() {
// sizes
"n" | "m" | "kl" | "ku" | "kd" | "nrhs" => true,
// flags
"itype" | "uplo" | "trans" | "balanc" | "sense" | "sort" => true,
// leading dimensions
name if name.starts_with("ld") => true,
// increments
name if name.starts_with("inc") => true,
// jobu / jobvt for SVD
name if name.starts_with("job") => true,
// pre-calculated norm value
name if name.ends_with("norm") => true,
// l*work is size of working memory
name if name.starts_with("l") && name.ends_with("work") => true,
_ => false,
}
}

fn is_mut_ref(name: &str) -> bool {
match name.to_lowercase().as_str() {
"info" => true,
// reciprocal of conditional number output
"rcond" => true,
// number of eigenvalues
"sdim" => true,
_ => false,
}
}

/// Convert pointer-based raw-LAPACK API into value and reference based API
fn signature_input(args: &Args) -> Args {
args.iter()
.cloned()
.map(|mut arg| {
match &mut arg {
syn::FnArg::Typed(arg) => {
let (name, arg_type) = parse_input(&arg);
let new_type = match arg_type {
ArgType::MutPtr(ty) => match name {
name if is_mut_ref(&name) => format!("&mut {}", ty),
_ => format!("&mut [{}]", ty),
},
ArgType::ConstPtr(ty) => match name {
name if is_value(&name) => format!("{}", ty),
_ => format!("&[{}]", ty),
},
ArgType::Value(ty) => ty,
};
*arg.ty = syn::parse_str(&new_type).unwrap();
}
_ => unreachable!("LAPACK raw API does not contains non-typed argument"),
}
arg
})
.collect()
}

fn call(args: &Args) -> Call {
args.iter()
.map(|arg| match arg {
syn::FnArg::Typed(arg) => {
let (name, arg_type) = parse_input(arg);
let expr = match arg_type {
ArgType::MutPtr(_) => match name {
name if is_mut_ref(&name) => name,
_ => format!("{}.as_mut_ptr()", name),
},
ArgType::ConstPtr(_) => match name {
name if is_value(&name) => format!("&{}", name),
_ => format!("{}.as_ptr()", name),
},
ArgType::Value(_) => name,
};
syn::parse_str::<syn::Expr>(&expr).unwrap()
}
_ => unreachable!(),
})
.collect()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn signature_input() {
let dgetrs = r#"
pub fn dgetrs_(
trans: *const c_char,
n: *const c_int,
nrhs: *const c_int,
A: *const f64,
lda: *const c_int,
ipiv: *const c_int,
B: *mut f64,
ldb: *const c_int,
info: *mut c_int,
);
"#;
let f: syn::ForeignItemFn = syn::parse_str(dgetrs).unwrap();
let result = super::signature_input(&f.sig.inputs);
let result_str = quote! { #result }.to_string();
let answer: TokenStream2 = syn::parse_str(
r#"
trans: c_char,
n: c_int,
nrhs: c_int,
A: &[f64],
lda: c_int,
ipiv: &[c_int],
B: &mut [f64],
ldb: c_int,
info: &mut c_int
"#,
)
.unwrap();
assert_eq!(result_str, answer.to_string());
}

#[test]
fn call() {
let dgetrs = r#"
pub fn dgetrs_(
trans: *const c_char,
n: *const c_int,
nrhs: *const c_int,
A: *const f64,
lda: *const c_int,
ipiv: *const c_int,
B: *mut f64,
ldb: *const c_int,
info: *mut c_int,
);
"#;
let f: syn::ForeignItemFn = syn::parse_str(dgetrs).unwrap();
let result = super::call(&f.sig.inputs);
let result_str = quote! { #result }.to_string();
let answer: TokenStream2 = syn::parse_str(
r#"
&trans,
&n,
&nrhs,
A.as_ptr(),
&lda,
ipiv.as_ptr(),
B.as_mut_ptr(),
&ldb,
info
"#,
)
.unwrap();
assert_eq!(result_str, answer.to_string());
}

#[test]
fn wrap_dgetrs() {
let dgetrs = r#"
pub fn dgetrs_(
trans: *const c_char,
n: *const c_int,
nrhs: *const c_int,
A: *const f64,
lda: *const c_int,
ipiv: *const c_int,
B: *mut f64,
ldb: *const c_int,
info: *mut c_int,
);
"#;
let wrapped = super::wrap(&syn::parse_str(dgetrs).unwrap());
let expected = r#"
pub unsafe fn dgetrs(
trans: c_char,
n: c_int,
nrhs: c_int,
A: &[f64],
lda: c_int,
ipiv: &[c_int],
B: &mut [f64],
ldb: c_int,
info: &mut c_int
) {
dgetrs_(
&trans,
&n,
&nrhs,
A.as_ptr(),
&lda,
ipiv.as_ptr(),
B.as_mut_ptr(),
&ldb,
info
)
}
"#;
let expected: TokenStream2 = syn::parse_str(expected).unwrap();
assert_eq!(wrapped.to_string(), expected.to_string());
}

/// Test for return value case
#[test]
fn wrap_dlange() {
let dgetrs = r#"
pub fn dlange_(
norm: *const c_char,
m: *const c_int,
n: *const c_int,
A: *const f64,
lda: *const c_int,
work: *mut f64,
) -> f64;
"#;
let wrapped = super::wrap(&syn::parse_str(dgetrs).unwrap());
let expected = r#"
pub unsafe fn dlange(
norm: c_char,
m: c_int,
n: c_int,
A: &[f64],
lda: c_int,
work: &mut [f64]
) -> f64 {
dlange_(
&norm,
&m,
&n,
A.as_ptr(),
&lda,
work.as_mut_ptr()
)
}
"#;
let expected: TokenStream2 = syn::parse_str(expected).unwrap();
assert_eq!(wrapped.to_string(), expected.to_string());
}
}
Loading