1
- use std:: { collections:: { HashMap , HashSet } , ffi:: { c_void, CStr , CString } , ptr:: NonNull , sync:: { LazyLock , OnceLock } } ;
2
-
3
- macro_rules! make_cstr {
4
- ( $s: expr) => { {
5
- const BASE : & str = $s;
6
- const LEN : usize = BASE . len( ) + 1 ;
7
- const RET_P : [ u8 ; LEN ] = const {
8
- let mut ret: [ u8 ; LEN ] = [ 0 ; LEN ] ;
9
- let mut idx = 0 ;
10
- loop {
11
- if idx == LEN - 1 {
12
- break ;
13
- }
14
- ret[ idx] = BASE . as_bytes( ) [ idx] ;
15
- idx += 1 ;
16
- }
17
- ret
18
- } ;
19
- const { unsafe { std:: ffi:: CStr :: from_bytes_with_nul_unchecked( & RET_P ) } }
20
- } } ;
21
- }
1
+ //! Allows changing which functions are used (C or Rust) via environment variable
2
+ //!
3
+ //! Set `METIS_OVERRIDE_SYMS` to do so. See [`translation.md`](../translation.md) for more info.
4
+
5
+ use crate :: util:: make_cstr;
6
+ use std:: { borrow:: Cow , collections:: { HashMap , HashSet } , ffi:: { c_void, CStr , CString } , ptr:: NonNull , sync:: { LazyLock , OnceLock } } ;
22
7
23
8
pub static LIBMETIS : Library = Library :: new (
24
9
make_cstr ! ( env!( "LIBMETIS_PORTED" ) )
@@ -80,55 +65,98 @@ enum Version {
80
65
}
81
66
82
67
const VAR : & str = "METIS_OVERRIDE_SYMS" ;
83
- static SYM_OVERRIDES : LazyLock < HashMap < CString , Version > > = LazyLock :: new ( init_overrides) ;
84
- fn init_overrides ( ) -> HashMap < CString , Version > {
85
- use std:: io:: Write ;
86
- let Some ( args) = std:: env:: var_os ( VAR ) else {
87
- return HashMap :: new ( )
88
- } ;
89
- let Ok ( args) = args. into_string ( ) else {
90
- let mut out = std:: io:: stderr ( ) ;
91
- writeln ! ( out, "{VAR} is invalid utf-8" ) . unwrap ( ) ;
92
- return HashMap :: new ( )
93
- } ;
94
- let mut ret = HashMap :: new ( ) ;
95
- for arg in args. split ( ',' ) {
96
- let ( sym, spec) = if let Some ( split@( sym, _) ) = arg. split_once ( ':' ) {
97
- if sym == "c" || sym == "rs" {
98
- let mut out = std:: io:: stderr ( ) ;
99
- writeln ! ( out, "Schema: <symbol>:<version> OR <full_symbol>" ) . unwrap ( ) ;
100
- continue
101
- }
102
- split
103
- } else if let Some ( sym) = arg. strip_prefix ( "c__" ) {
104
- ( sym, "c" )
105
- } else if let Some ( sym) = arg. strip_prefix ( "rs__" ) {
106
- ( sym, "rs" )
107
- } else {
108
- ( arg, "c" )
68
+ static SYM_OVERRIDES : LazyLock < Overrides > = LazyLock :: new ( Overrides :: init_overrides) ;
69
+ #[ derive( Default ) ]
70
+ struct Overrides {
71
+ globs : Vec < ( Glob < ' static > , Version ) > ,
72
+ exact : HashMap < Box < [ u8 ] > , Version >
73
+ }
74
+
75
+ impl Overrides {
76
+ fn get ( & self , name : impl AsRef < [ u8 ] > ) -> Version {
77
+ let name = name. as_ref ( ) ;
78
+ let name = name. strip_suffix ( & [ 0u8 ] ) . unwrap_or ( name) ;
79
+ if let Some ( & exact_ver) = self . exact . get ( name) {
80
+ return exact_ver
81
+ }
82
+ let short_name = name. strip_prefix ( b"c__" ) . unwrap_or ( name) ;
83
+ let short_name = short_name. strip_prefix ( b"rs__" ) . unwrap_or ( short_name) ;
84
+ let short_name = short_name. strip_prefix ( b"libmetis__" ) . unwrap_or ( short_name) ;
85
+ if let Some ( & ( _, glob_ver) ) = self . globs . iter ( ) . rev ( ) . find ( |( glob, _) | glob. matches ( short_name) ) {
86
+ return glob_ver
87
+ }
88
+ Version :: Rust
89
+ }
90
+
91
+ fn init_overrides ( ) -> Self {
92
+ use std:: io:: Write ;
93
+ let Some ( args) = std:: env:: var_os ( VAR ) else {
94
+ return Overrides :: default ( ) ;
109
95
} ;
110
- let lib_pfx = if sym . starts_with ( "libmetis__" ) || EXPORTS . contains ( & sym ) {
111
- ""
112
- } else {
113
- "libmetis__"
96
+ let Ok ( args ) = args . into_string ( ) else {
97
+ let mut out = std :: io :: stderr ( ) ;
98
+ writeln ! ( out , "{VAR} is invalid utf-8" ) . unwrap ( ) ;
99
+ return Overrides :: default ( ) ;
114
100
} ;
115
- let ver = {
116
- match spec {
117
- "c" => Version :: C ,
118
- "rs" => Version :: Rust ,
119
- _ => {
101
+ let mut ret = Self :: default ( ) ;
102
+ for arg in args. split ( ',' ) {
103
+ if arg. contains ( '*' ) {
104
+ // this is a glob!
105
+ let ( glob, spec) = if let Some ( split) = arg. split_once ( ':' ) {
106
+ split
107
+ } else {
108
+ ( arg, "c" )
109
+ } ;
110
+ let ver = {
111
+ match spec {
112
+ "c" => Version :: C ,
113
+ "rs" => Version :: Rust ,
114
+ _ => {
115
+ let mut out = std:: io:: stderr ( ) ;
116
+ writeln ! ( out, "Bad spec: {spec:?}" ) . unwrap ( ) ;
117
+ continue
118
+ }
119
+ }
120
+ } ;
121
+ ret. globs . push ( ( Glob :: new_owned ( glob) , ver) ) ;
122
+ continue ;
123
+ }
124
+ let ( sym, spec) = if let Some ( split@( sym, _) ) = arg. split_once ( ':' ) {
125
+ if sym == "c" || sym == "rs" {
120
126
let mut out = std:: io:: stderr ( ) ;
121
- writeln ! ( out, "Bad spec: {spec:?} " ) . unwrap ( ) ;
127
+ writeln ! ( out, "Schema: <symbol>:<version> OR <full_symbol> " ) . unwrap ( ) ;
122
128
continue
123
129
}
124
- }
125
- } ;
126
- // always c__ since that's what we lookup with dlsym
127
- let sym = format ! ( "c__{lib_pfx}{sym}\0 " ) ;
128
- let sym = CString :: from_vec_with_nul ( sym. into_bytes ( ) ) . unwrap ( ) ;
129
- ret. insert ( sym, ver) ;
130
+ split
131
+ } else if let Some ( sym) = arg. strip_prefix ( "c__" ) {
132
+ ( sym, "c" )
133
+ } else if let Some ( sym) = arg. strip_prefix ( "rs__" ) {
134
+ ( sym, "rs" )
135
+ } else {
136
+ ( arg, "c" )
137
+ } ;
138
+ let lib_pfx = if sym. starts_with ( "libmetis__" ) || EXPORTS . contains ( & sym) {
139
+ ""
140
+ } else {
141
+ "libmetis__"
142
+ } ;
143
+ let ver = {
144
+ match spec {
145
+ "c" => Version :: C ,
146
+ "rs" => Version :: Rust ,
147
+ _ => {
148
+ let mut out = std:: io:: stderr ( ) ;
149
+ writeln ! ( out, "Bad spec: {spec:?}" ) . unwrap ( ) ;
150
+ continue
151
+ }
152
+ }
153
+ } ;
154
+ // always c__ since that's what we lookup with dlsym
155
+ let sym = format ! ( "c__{lib_pfx}{sym}" ) . into_bytes ( ) ;
156
+ ret. exact . insert ( sym. into ( ) , ver) ;
157
+ }
158
+ ret
130
159
}
131
- ret
132
160
}
133
161
134
162
fn clear_dlerror ( ) {
@@ -183,11 +211,7 @@ impl ICall {
183
211
// println!("{overrides:?}");
184
212
// panic!("");
185
213
* self . func . get_or_init ( || {
186
- let ver = if let Some ( & ver) = SYM_OVERRIDES . get ( self . sym_name ) {
187
- ver
188
- } else {
189
- Version :: Rust
190
- } ;
214
+ let ver = SYM_OVERRIDES . get ( self . sym_name . to_bytes ( ) ) ;
191
215
match ver {
192
216
Version :: Rust => self . rs_ver ,
193
217
Version :: C => {
@@ -209,3 +233,132 @@ impl ICall {
209
233
} )
210
234
}
211
235
}
236
+
237
+ /// Helper for grouping functions -- very dumb and can get very slow if there are too many `*`
238
+ pub struct Glob < ' a > {
239
+ template : Cow < ' a , [ u8 ] > ,
240
+ }
241
+
242
+ impl Glob < ' static > {
243
+ #[ allow( dead_code) ]
244
+ pub fn new_owned ( g : impl AsRef < [ u8 ] > ) -> Self {
245
+ Self {
246
+ template : Cow :: Owned ( g. as_ref ( ) . to_owned ( ) ) ,
247
+ }
248
+ }
249
+
250
+ }
251
+
252
+ impl < ' a > Glob < ' a > {
253
+ #[ allow( dead_code) ]
254
+ pub const fn new_str ( b : & ' a str ) -> Self {
255
+ Self {
256
+ template : Cow :: Borrowed ( b. as_bytes ( ) )
257
+ }
258
+ }
259
+
260
+ #[ allow( dead_code) ]
261
+ pub const fn new_bytes ( b : & ' a [ u8 ] ) -> Self {
262
+ Self {
263
+ template : Cow :: Borrowed ( b)
264
+ }
265
+ }
266
+
267
+ // TODO: optimize me!
268
+ pub fn matches ( & self , s : impl AsRef < [ u8 ] > ) -> bool {
269
+ fn slices ( s : & [ u8 ] ) -> impl Iterator < Item = & [ u8 ] > {
270
+ ( 0 ..s. len ( ) ) . map ( |i| & s[ i..] )
271
+ }
272
+ fn subslices < ' a > ( haystack : & ' a [ u8 ] , needle : & ' a [ u8 ] ) -> impl Iterator < Item = & ' a [ u8 ] > {
273
+ slices ( haystack) . filter_map ( |slice| slice. strip_prefix ( needle) )
274
+ }
275
+ fn initial ( mut g : & [ u8 ] , mut s : & [ u8 ] ) -> bool {
276
+ let Some ( star_idx) = g. iter ( ) . position ( |& c| c == b'*' ) else {
277
+ return g == s
278
+ } ;
279
+ if & g[ ..star_idx] != & s[ ..star_idx] {
280
+ return false
281
+ }
282
+ g = & g[ star_idx + 1 ..] ;
283
+ s = & s[ star_idx..] ;
284
+ inner ( g, s)
285
+ }
286
+ /// assumes g starts with an implicit `*`
287
+ fn inner ( mut g : & [ u8 ] , s : & [ u8 ] ) -> bool {
288
+ // eprintln!("called with => g: {:?}, s: {:?}", std::str::from_utf8(g), std::str::from_utf8(s));
289
+ let leading_stars = g. iter ( ) . take_while ( |& & c| c == b'*' ) . count ( ) ;
290
+ g = & g[ leading_stars..] ;
291
+ if g. is_empty ( ) {
292
+ // eprintln!("empty glob => {:?}", std::str::from_utf8(s));
293
+ return true
294
+ }
295
+ if let Some ( lit_len) = g. iter ( ) . position ( |& c| c == b'*' ) {
296
+ debug_assert ! ( lit_len >= 1 ) ;
297
+ let lit = & g[ ..lit_len] ;
298
+ g = & g[ lit_len..] ;
299
+ for s in subslices ( s, lit) {
300
+ if inner ( g, s) {
301
+ return true
302
+ }
303
+ }
304
+ false
305
+ } else {
306
+ s. ends_with ( g)
307
+ }
308
+ }
309
+ initial ( & * self . template , s. as_ref ( ) )
310
+ }
311
+ }
312
+
313
+
314
+ #[ cfg( test) ]
315
+ mod tests {
316
+ use super :: * ;
317
+
318
+ #[ test]
319
+ fn glob_matches ( ) {
320
+ #[ track_caller]
321
+ fn case ( glob : & str , haystack : & str ) {
322
+ eprintln ! ( "begin with ===> g: {glob:?}, s: {haystack:?}" ) ;
323
+ let g = Glob :: new_bytes ( glob. as_bytes ( ) ) ;
324
+ assert ! ( g. matches( haystack) , "glob {glob:?} did not match {haystack:?}" )
325
+ }
326
+ case ( "abc" , "abc" ) ;
327
+ case ( "a" , "a" ) ;
328
+ case ( "" , "" ) ;
329
+ case ( "*" , "abc" ) ;
330
+ case ( "*" , "a" ) ;
331
+ case ( "*" , "" ) ;
332
+ case ( "a*" , "a" ) ;
333
+ case ( "a*" , "abc" ) ;
334
+ case ( "*a*" , "abc" ) ;
335
+ case ( "*A*" , " A " ) ;
336
+ case ( "*A" , " A" ) ;
337
+ case ( "*A" , " AA" ) ;
338
+ case ( "*A" , " CBA" ) ;
339
+ case ( "*ABC" , "abcABC" ) ;
340
+ case ( "S*MID*E" , "S---MID---E" ) ;
341
+ case ( "S**MID**E" , "S---MID---E" ) ;
342
+ case ( "S*1*2*E" , "S---1--2---E" ) ;
343
+ case ( "S*1*2*E" , "S---12---E" ) ;
344
+ case ( "S*12*3*E" , "S---12--3---E" ) ;
345
+ }
346
+
347
+ #[ test]
348
+ fn glob_matches_not ( ) {
349
+ #[ track_caller]
350
+ fn case ( glob : & str , haystack : & str ) {
351
+ let g = Glob :: new_bytes ( glob. as_bytes ( ) ) ;
352
+ assert ! ( !g. matches( haystack) , "glob {glob:?} matched {haystack:?}" )
353
+ }
354
+ case ( "" , "abc" ) ;
355
+ case ( "a" , "ab" ) ;
356
+ case ( "a" , "ba" ) ;
357
+ case ( "a*" , "ba" ) ;
358
+ case ( "a*" , "bac" ) ;
359
+ case ( "*A" , "--A-" ) ;
360
+ case ( "*A" , "--AA-" ) ;
361
+ case ( "*A" , "A-" ) ;
362
+ case ( "S*1*2*E" , "S---13---E" ) ;
363
+ }
364
+ }
0 commit comments