Skip to content

Commit 883a8dd

Browse files
committed
add globs to overrides
1 parent 9584925 commit 883a8dd

File tree

3 files changed

+257
-69
lines changed

3 files changed

+257
-69
lines changed

src/dyncall.rs

Lines changed: 222 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,9 @@
1-
use std::{collections::{HashMap, HashSet}, ffi::{c_void, CStr, CString}, ptr::NonNull, sync::{LazyLock, OnceLock}};
2-
3-
macro_rules! make_cstr {
4-
($s:expr) => {{
5-
const BASE: &str = $s;
6-
const LEN: usize = BASE.len() + 1;
7-
const RET_P: [u8; LEN] = const {
8-
let mut ret: [u8; LEN] = [0; LEN];
9-
let mut idx = 0;
10-
loop {
11-
if idx == LEN - 1 {
12-
break;
13-
}
14-
ret[idx] = BASE.as_bytes()[idx];
15-
idx += 1;
16-
}
17-
ret
18-
};
19-
const { unsafe { std::ffi::CStr::from_bytes_with_nul_unchecked(&RET_P) } }
20-
}};
21-
}
1+
//! Allows changing which functions are used (C or Rust) via environment variable
2+
//!
3+
//! Set `METIS_OVERRIDE_SYMS` to do so. See [`translation.md`](../translation.md) for more info.
4+
5+
use crate::util::make_cstr;
6+
use std::{borrow::Cow, collections::{HashMap, HashSet}, ffi::{c_void, CStr, CString}, ptr::NonNull, sync::{LazyLock, OnceLock}};
227

238
pub static LIBMETIS: Library = Library::new(
249
make_cstr!(env!("LIBMETIS_PORTED"))
@@ -80,55 +65,98 @@ enum Version {
8065
}
8166

8267
const VAR: &str = "METIS_OVERRIDE_SYMS";
83-
static SYM_OVERRIDES: LazyLock<HashMap<CString, Version>> = LazyLock::new(init_overrides);
84-
fn init_overrides() -> HashMap<CString, Version> {
85-
use std::io::Write;
86-
let Some(args) = std::env::var_os(VAR) else {
87-
return HashMap::new()
88-
};
89-
let Ok(args) = args.into_string() else {
90-
let mut out = std::io::stderr();
91-
writeln!(out, "{VAR} is invalid utf-8").unwrap();
92-
return HashMap::new()
93-
};
94-
let mut ret = HashMap::new();
95-
for arg in args.split(',') {
96-
let (sym, spec) = if let Some(split@(sym, _)) = arg.split_once(':') {
97-
if sym == "c" || sym == "rs" {
98-
let mut out = std::io::stderr();
99-
writeln!(out, "Schema: <symbol>:<version> OR <full_symbol>").unwrap();
100-
continue
101-
}
102-
split
103-
} else if let Some(sym) = arg.strip_prefix("c__") {
104-
(sym, "c")
105-
} else if let Some(sym) = arg.strip_prefix("rs__") {
106-
(sym, "rs")
107-
} else {
108-
(arg, "c")
68+
static SYM_OVERRIDES: LazyLock<Overrides> = LazyLock::new(Overrides::init_overrides);
69+
#[derive(Default)]
70+
struct Overrides {
71+
globs: Vec<(Glob<'static>, Version)>,
72+
exact: HashMap<Box<[u8]>, Version>
73+
}
74+
75+
impl Overrides {
76+
fn get(&self, name: impl AsRef<[u8]>) -> Version {
77+
let name = name.as_ref();
78+
let name = name.strip_suffix(&[0u8]).unwrap_or(name);
79+
if let Some(&exact_ver) = self.exact.get(name) {
80+
return exact_ver
81+
}
82+
let short_name = name.strip_prefix(b"c__").unwrap_or(name);
83+
let short_name = short_name.strip_prefix(b"rs__").unwrap_or(short_name);
84+
let short_name = short_name.strip_prefix(b"libmetis__").unwrap_or(short_name);
85+
if let Some(&(_, glob_ver)) = self.globs.iter().rev().find(|(glob, _)| glob.matches(short_name)) {
86+
return glob_ver
87+
}
88+
Version::Rust
89+
}
90+
91+
fn init_overrides() -> Self {
92+
use std::io::Write;
93+
let Some(args) = std::env::var_os(VAR) else {
94+
return Overrides::default();
10995
};
110-
let lib_pfx = if sym.starts_with("libmetis__") || EXPORTS.contains(&sym) {
111-
""
112-
} else {
113-
"libmetis__"
96+
let Ok(args) = args.into_string() else {
97+
let mut out = std::io::stderr();
98+
writeln!(out, "{VAR} is invalid utf-8").unwrap();
99+
return Overrides::default();
114100
};
115-
let ver = {
116-
match spec {
117-
"c" => Version::C,
118-
"rs" => Version::Rust,
119-
_ => {
101+
let mut ret = Self::default();
102+
for arg in args.split(',') {
103+
if arg.contains('*') {
104+
// this is a glob!
105+
let (glob, spec) = if let Some(split) = arg.split_once(':') {
106+
split
107+
} else {
108+
(arg, "c")
109+
};
110+
let ver = {
111+
match spec {
112+
"c" => Version::C,
113+
"rs" => Version::Rust,
114+
_ => {
115+
let mut out = std::io::stderr();
116+
writeln!(out, "Bad spec: {spec:?}").unwrap();
117+
continue
118+
}
119+
}
120+
};
121+
ret.globs.push((Glob::new_owned(glob), ver));
122+
continue;
123+
}
124+
let (sym, spec) = if let Some(split@(sym, _)) = arg.split_once(':') {
125+
if sym == "c" || sym == "rs" {
120126
let mut out = std::io::stderr();
121-
writeln!(out, "Bad spec: {spec:?}").unwrap();
127+
writeln!(out, "Schema: <symbol>:<version> OR <full_symbol>").unwrap();
122128
continue
123129
}
124-
}
125-
};
126-
// always c__ since that's what we lookup with dlsym
127-
let sym = format!("c__{lib_pfx}{sym}\0");
128-
let sym = CString::from_vec_with_nul(sym.into_bytes()).unwrap();
129-
ret.insert(sym, ver);
130+
split
131+
} else if let Some(sym) = arg.strip_prefix("c__") {
132+
(sym, "c")
133+
} else if let Some(sym) = arg.strip_prefix("rs__") {
134+
(sym, "rs")
135+
} else {
136+
(arg, "c")
137+
};
138+
let lib_pfx = if sym.starts_with("libmetis__") || EXPORTS.contains(&sym) {
139+
""
140+
} else {
141+
"libmetis__"
142+
};
143+
let ver = {
144+
match spec {
145+
"c" => Version::C,
146+
"rs" => Version::Rust,
147+
_ => {
148+
let mut out = std::io::stderr();
149+
writeln!(out, "Bad spec: {spec:?}").unwrap();
150+
continue
151+
}
152+
}
153+
};
154+
// always c__ since that's what we lookup with dlsym
155+
let sym = format!("c__{lib_pfx}{sym}").into_bytes();
156+
ret.exact.insert(sym.into(), ver);
157+
}
158+
ret
130159
}
131-
ret
132160
}
133161

134162
fn clear_dlerror() {
@@ -183,11 +211,7 @@ impl ICall {
183211
// println!("{overrides:?}");
184212
// panic!("");
185213
*self.func.get_or_init(|| {
186-
let ver = if let Some(&ver) = SYM_OVERRIDES.get(self.sym_name) {
187-
ver
188-
} else {
189-
Version::Rust
190-
};
214+
let ver = SYM_OVERRIDES.get(self.sym_name.to_bytes());
191215
match ver {
192216
Version::Rust => self.rs_ver,
193217
Version::C => {
@@ -209,3 +233,132 @@ impl ICall {
209233
})
210234
}
211235
}
236+
237+
/// Helper for grouping functions -- very dumb and can get very slow if there are too many `*`
238+
pub struct Glob<'a> {
239+
template: Cow<'a, [u8]>,
240+
}
241+
242+
impl Glob<'static> {
243+
#[allow(dead_code)]
244+
pub fn new_owned(g: impl AsRef<[u8]>) -> Self {
245+
Self {
246+
template: Cow::Owned(g.as_ref().to_owned()),
247+
}
248+
}
249+
250+
}
251+
252+
impl<'a> Glob<'a> {
253+
#[allow(dead_code)]
254+
pub const fn new_str(b: &'a str) -> Self {
255+
Self {
256+
template: Cow::Borrowed(b.as_bytes())
257+
}
258+
}
259+
260+
#[allow(dead_code)]
261+
pub const fn new_bytes(b: &'a [u8]) -> Self {
262+
Self {
263+
template: Cow::Borrowed(b)
264+
}
265+
}
266+
267+
// TODO: optimize me!
268+
pub fn matches(&self, s: impl AsRef<[u8]>) -> bool {
269+
fn slices(s: &[u8]) -> impl Iterator<Item = &[u8]> {
270+
(0..s.len()).map(|i| &s[i..])
271+
}
272+
fn subslices<'a>(haystack: &'a [u8], needle: &'a [u8]) -> impl Iterator<Item = &'a [u8]> {
273+
slices(haystack).filter_map(|slice| slice.strip_prefix(needle))
274+
}
275+
fn initial(mut g: &[u8], mut s: &[u8]) -> bool {
276+
let Some(star_idx) = g.iter().position(|&c| c == b'*') else {
277+
return g == s
278+
};
279+
if &g[..star_idx] != &s[..star_idx] {
280+
return false
281+
}
282+
g = &g[star_idx + 1..];
283+
s = &s[star_idx..];
284+
inner(g, s)
285+
}
286+
/// assumes g starts with an implicit `*`
287+
fn inner(mut g: &[u8], s: &[u8]) -> bool {
288+
// eprintln!("called with => g: {:?}, s: {:?}", std::str::from_utf8(g), std::str::from_utf8(s));
289+
let leading_stars = g.iter().take_while(|&&c| c == b'*').count();
290+
g = &g[leading_stars..];
291+
if g.is_empty() {
292+
// eprintln!("empty glob => {:?}", std::str::from_utf8(s));
293+
return true
294+
}
295+
if let Some(lit_len) = g.iter().position(|&c| c == b'*') {
296+
debug_assert!(lit_len >= 1);
297+
let lit = &g[..lit_len];
298+
g = &g[lit_len..];
299+
for s in subslices(s, lit) {
300+
if inner(g, s) {
301+
return true
302+
}
303+
}
304+
false
305+
} else {
306+
s.ends_with(g)
307+
}
308+
}
309+
initial(&*self.template, s.as_ref())
310+
}
311+
}
312+
313+
314+
#[cfg(test)]
315+
mod tests {
316+
use super::*;
317+
318+
#[test]
319+
fn glob_matches() {
320+
#[track_caller]
321+
fn case(glob: &str, haystack: &str) {
322+
eprintln!("begin with ===> g: {glob:?}, s: {haystack:?}");
323+
let g = Glob::new_bytes(glob.as_bytes());
324+
assert!(g.matches(haystack), "glob {glob:?} did not match {haystack:?}")
325+
}
326+
case("abc", "abc");
327+
case("a", "a");
328+
case("", "");
329+
case("*", "abc");
330+
case("*", "a");
331+
case("*", "");
332+
case("a*", "a");
333+
case("a*", "abc");
334+
case("*a*", "abc");
335+
case("*A*", " A ");
336+
case("*A", " A");
337+
case("*A", " AA");
338+
case("*A", " CBA");
339+
case("*ABC", "abcABC");
340+
case("S*MID*E", "S---MID---E");
341+
case("S**MID**E", "S---MID---E");
342+
case("S*1*2*E", "S---1--2---E");
343+
case("S*1*2*E", "S---12---E");
344+
case("S*12*3*E", "S---12--3---E");
345+
}
346+
347+
#[test]
348+
fn glob_matches_not() {
349+
#[track_caller]
350+
fn case(glob: &str, haystack: &str) {
351+
let g = Glob::new_bytes(glob.as_bytes());
352+
assert!(!g.matches(haystack), "glob {glob:?} matched {haystack:?}")
353+
}
354+
case("", "abc");
355+
case("a", "ab");
356+
case("a", "ba");
357+
case("a*", "ba");
358+
case("a*", "bac");
359+
case("*A", "--A-");
360+
case("*A", "--AA-");
361+
case("*A", "A-");
362+
case("S*1*2*E", "S---13---E");
363+
}
364+
}

src/util.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,28 @@ macro_rules! mkslice {
673673
};
674674
}
675675

676+
macro_rules! make_cstr {
677+
($s:expr) => {{
678+
const BASE: &str = $s;
679+
const LEN: usize = BASE.len() + 1;
680+
const RET_P: [u8; LEN] = const {
681+
let mut ret: [u8; LEN] = [0; LEN];
682+
let mut idx = 0;
683+
loop {
684+
if idx == LEN - 1 {
685+
break;
686+
}
687+
ret[idx] = BASE.as_bytes()[idx];
688+
idx += 1;
689+
}
690+
ret
691+
};
692+
const { unsafe { std::ffi::CStr::from_bytes_with_nul_unchecked(&RET_P) } }
693+
}};
694+
}
695+
pub(crate) use make_cstr;
696+
697+
676698
/// makes a slice or initialize a vec with a default value
677699
/// ```
678700
/// # use metis::slice_default;

translation.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,19 @@ METIS_OVERRIDE_SYMS="libmetis__SetupCoarseGraph:rs"
388388
METIS_OVERRIDE_SYMS="rs__libmetis__SetupCoarseGraph"
389389
```
390390

391+
If there's a `*` in the symbol, it will be treated as a glob. Exact matches
392+
always have priority over globs, and the last glob matched is the version used.
393+
Note that globs match against the base symbol, which excludes `c__libmetis__`.
394+
395+
```
396+
# only use C versions of functions
397+
METIS_OVERRIDE_SYMS="*"
398+
399+
# only use C versions of functions, except SetupCoarseGraph
400+
METIS_OVERRIDE_SYMS="*,SetupCoarseGraph:rs"
401+
```
402+
403+
391404
## Things to look out for
392405

393406
- make sure that the `gk_malloc` calls have null-terminated strings

0 commit comments

Comments
 (0)