@@ -51,6 +51,9 @@ unsafe extern "C" {
5151 pub ( crate ) fn LLVMGetNamedFunction ( M : & Module , Name : * const c_char ) -> Option < & Value > ;
5252}
5353
54+
55+
56+
5457#[ repr( C ) ]
5558#[ derive( Copy , Clone , PartialEq ) ]
5659pub ( crate ) enum LLVMRustVerifierFailureAction {
@@ -59,113 +62,201 @@ pub(crate) enum LLVMRustVerifierFailureAction {
5962 LLVMReturnStatusAction = 2 ,
6063}
6164
62- #[ cfg( not ( llvm_enzyme) ) ]
65+ #[ cfg( llvm_enzyme) ]
6366pub ( crate ) use self :: Enzyme_AD :: * ;
6467
65- #[ cfg( not ( llvm_enzyme) ) ]
68+ // #[cfg(llvm_enzyme)]
6669pub ( crate ) mod Enzyme_AD {
67- use std:: ffi:: { CString , c_char} ;
70+ use std:: ffi:: CString ;
71+ //use std::ffi::{CString, c_char};
6872
6973 use libc:: c_void;
7074
71- unsafe extern "C" {
72- pub ( crate ) fn EnzymeSetCLBool ( arg1 : * mut :: std:: os:: raw:: c_void , arg2 : u8 ) ;
73- pub ( crate ) fn EnzymeSetCLString ( arg1 : * mut :: std:: os:: raw:: c_void , arg2 : * const c_char ) ;
74- }
75- unsafe extern "C" {
76- static mut EnzymePrintPerf : c_void ;
77- static mut EnzymePrintActivity : c_void ;
78- static mut EnzymePrintType : c_void ;
79- static mut EnzymeFunctionToAnalyze : c_void ;
80- static mut EnzymePrint : c_void ;
81- static mut EnzymeStrictAliasing : c_void ;
82- static mut looseTypeAnalysis: c_void ;
83- static mut EnzymeInline : c_void ;
84- static mut RustTypeRules : c_void ;
75+ type SetFlag = unsafe extern "C" fn ( * mut c_void , u8 ) ;
76+
77+ #[ derive( Debug ) ]
78+ pub ( crate ) struct EnzymeFns {
79+ pub set_cl : SetFlag ,
80+ }
81+
82+ #[ derive( Debug ) ]
83+ pub ( crate ) struct EnzymeWrapper {
84+ EnzymePrintPerf : * mut c_void ,
85+ EnzymePrintActivity : * mut c_void ,
86+ EnzymePrintType : * mut c_void ,
87+ EnzymeFunctionToAnalyze : * mut c_void ,
88+ EnzymePrint : * mut c_void ,
89+ EnzymeStrictAliasing : * mut c_void ,
90+ looseTypeAnalysis : * mut c_void ,
91+ EnzymeInline : * mut c_void ,
92+ RustTypeRules : * mut c_void ,
93+
94+ EnzymeSetCLBool : EnzymeFns ,
95+ EnzymeSetCLString : EnzymeFns ,
96+ pub registerEnzymeAndPassPipeline : * const c_void ,
8597 }
86- pub ( crate ) fn set_print_perf ( print : bool ) {
87- unsafe {
88- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintPerf ) , print as u8 ) ;
98+ fn call_dynamic ( ) -> Result < EnzymeWrapper , Box < dyn std:: error:: Error > > {
99+ fn load_ptr ( lib : & libloading:: Library , bytes : & [ u8 ] ) -> Result < * mut c_void , Box < dyn std:: error:: Error > > {
100+ // Safety: symbol lookup from a loaded shared object.
101+ unsafe {
102+ let s: libloading:: Symbol < ' _ , * mut c_void > = lib. get ( bytes) ?;
103+ let s = s. try_as_raw_ptr ( ) . unwrap ( ) ;
104+ Ok ( s as * mut c_void )
105+ }
106+ }
107+ dbg ! ( "starting" ) ;
108+ dbg ! ( "Loading Enzyme" ) ;
109+ let lib = unsafe { libloading:: Library :: new ( "/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so" ) ?} ;
110+ dbg ! ( "second" ) ;
111+ let EnzymeSetCLBool : libloading:: Symbol < ' _ , SetFlag > = unsafe { lib. get ( b"EnzymeSetCLBool" ) ?} ;
112+ dbg ! ( "third" ) ;
113+ let registerEnzymeAndPassPipeline =
114+ load_ptr ( & lib, b"registerEnzymeAndPassPipeline" ) . unwrap ( ) as * const c_void ;
115+ dbg ! ( "fourth" ) ;
116+ //let EnzymeSetCLBool: libloading::Symbol<'_, unsafe extern "C" fn(&mut c_void, u8) -> ()> = unsafe{lib.get(b"registerEnzymeAndPassPipeline")?};
117+ //let EnzymeSetCLBool = unsafe {EnzymeSetCLBool.try_as_raw_ptr().unwrap()};
118+ let EnzymeSetCLString : libloading:: Symbol < ' _ , SetFlag > = unsafe { lib. get ( b"EnzymeSetCLString" ) ?} ;
119+ dbg ! ( "done" ) ;
120+ //let EnzymeSetCLString = unsafe {EnzymeSetCLString.try_as_raw_ptr().unwrap()};
121+
122+ let EnzymePrintPerf = load_ptr ( & lib, b"EnzymePrintPerf" ) . unwrap ( ) ;
123+ let EnzymePrintActivity = load_ptr ( & lib, b"EnzymePrintActivity" ) . unwrap ( ) ;
124+ let EnzymePrintType = load_ptr ( & lib, b"EnzymePrintType" ) . unwrap ( ) ;
125+ let EnzymeFunctionToAnalyze = load_ptr ( & lib, b"EnzymeFunctionToAnalyze" ) . unwrap ( ) ;
126+ let EnzymePrint = load_ptr ( & lib, b"EnzymePrint" ) . unwrap ( ) ;
127+
128+ let EnzymeStrictAliasing = load_ptr ( & lib, b"EnzymeStrictAliasing" ) . unwrap ( ) ;
129+ let looseTypeAnalysis = load_ptr ( & lib, b"looseTypeAnalysis" ) . unwrap ( ) ;
130+ let EnzymeInline = load_ptr ( & lib, b"EnzymeInline" ) . unwrap ( ) ;
131+ let RustTypeRules = load_ptr ( & lib, b"RustTypeRules" ) . unwrap ( ) ;
132+
133+ let wrap = EnzymeWrapper {
134+ EnzymePrintPerf ,
135+ EnzymePrintActivity ,
136+ EnzymePrintType ,
137+ EnzymeFunctionToAnalyze ,
138+ EnzymePrint ,
139+ EnzymeStrictAliasing ,
140+ looseTypeAnalysis,
141+ EnzymeInline ,
142+ RustTypeRules ,
143+ //EnzymeSetCLBool: EnzymeFns {set_cl: unsafe{*EnzymeSetCLBool}},
144+ //EnzymeSetCLString: EnzymeFns {set_cl: unsafe{*EnzymeSetCLString}},
145+ EnzymeSetCLBool : EnzymeFns { set_cl : * EnzymeSetCLBool } ,
146+ EnzymeSetCLString : EnzymeFns { set_cl : * EnzymeSetCLString } ,
147+ registerEnzymeAndPassPipeline,
148+ } ;
149+ dbg ! ( & wrap) ;
150+ Ok ( wrap)
89151 }
90- }
91- pub ( crate ) fn set_print_activity ( print : bool ) {
92- unsafe {
93- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintActivity ) , print as u8 ) ;
152+ use std:: sync:: Mutex ;
153+ unsafe impl Sync for EnzymeWrapper { }
154+ unsafe impl Send for EnzymeWrapper { }
155+ impl EnzymeWrapper {
156+ pub ( crate ) fn current ( ) -> & ' static Mutex < EnzymeWrapper > {
157+ use std:: sync:: OnceLock ;
158+ static CELL : OnceLock < Mutex < EnzymeWrapper > > = OnceLock :: new ( ) ;
159+ fn init_enzyme ( ) -> Mutex < EnzymeWrapper > {
160+ call_dynamic ( ) . unwrap ( ) . into ( )
161+ }
162+ CELL . get_or_init ( || init_enzyme ( ) )
94163 }
95- }
96- pub ( crate ) fn set_print_type ( print : bool ) {
97- unsafe {
98- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintType ) , print as u8 ) ;
164+ pub ( crate ) fn set_print_perf ( & mut self , print : bool ) {
165+ unsafe {
166+ //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintPerf, print as u8);
167+ //(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintPerf), print as u8);
168+ }
99169 }
100- }
101- pub ( crate ) fn set_print_type_fun ( fun_name : & str ) {
102- let c_fun_name = CString :: new ( fun_name) . unwrap ( ) ;
103- unsafe {
104- EnzymeSetCLString (
105- std:: ptr:: addr_of_mut!( EnzymeFunctionToAnalyze ) ,
106- c_fun_name. as_ptr ( ) as * const c_char ,
107- ) ;
170+
171+ pub ( crate ) fn set_print_activity ( & mut self , print : bool ) {
172+ unsafe {
173+ //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintActivity, print as u8);
174+ //(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintActivity), print as u8);
175+ }
108176 }
109- }
110- pub ( crate ) fn set_print ( print : bool ) {
111- unsafe {
112- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrint ) , print as u8 ) ;
177+
178+ pub ( crate ) fn set_print_type ( & mut self , print : bool ) {
179+ unsafe {
180+ // (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintType, print as u8);
181+ }
113182 }
114- }
115- pub ( crate ) fn set_strict_aliasing ( strict : bool ) {
116- unsafe {
117- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymeStrictAliasing ) , strict as u8 ) ;
183+
184+ pub ( crate ) fn set_print_type_fun ( & mut self , fun_name : & str ) {
185+ let _c_fun_name = CString :: new ( fun_name) . unwrap ( ) ;
186+ //unsafe {
187+ // (self.EnzymeSetCLString.set_cl)(
188+ // self.EnzymeFunctionToAnalyze,
189+ // c_fun_name.as_ptr() as *const c_char,
190+ // );
191+ //}
118192 }
119- }
120- pub ( crate ) fn set_loose_types ( loose : bool ) {
121- unsafe {
122- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( looseTypeAnalysis) , loose as u8 ) ;
193+
194+ pub ( crate ) fn set_print ( & mut self , print : bool ) {
195+ unsafe {
196+ //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrint, print as u8);
197+ }
123198 }
124- }
125- pub ( crate ) fn set_inline ( val : bool ) {
126- unsafe {
127- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymeInline ) , val as u8 ) ;
199+
200+ pub ( crate ) fn set_strict_aliasing ( & mut self , strict : bool ) {
201+ unsafe {
202+ //(self.EnzymeSetCLBool.set_cl)(self.EnzymeStrictAliasing, strict as u8);
203+ }
128204 }
129- }
130- pub ( crate ) fn set_rust_rules ( val : bool ) {
131- unsafe {
132- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( RustTypeRules ) , val as u8 ) ;
205+
206+ pub ( crate ) fn set_loose_types ( & mut self , loose : bool ) {
207+ unsafe {
208+ //(self.EnzymeSetCLBool.set_cl)(self.looseTypeAnalysis, loose as u8);
209+ }
210+ }
211+
212+ pub ( crate ) fn set_inline ( & mut self , val : bool ) {
213+ unsafe {
214+ //(self.EnzymeSetCLBool.set_cl)(self.EnzymeInline, val as u8);
215+ }
216+ }
217+
218+ pub ( crate ) fn set_rust_rules ( & mut self , val : bool ) {
219+ unsafe {
220+ //(self.EnzymeSetCLBool.set_cl)(self.RustTypeRules, val as u8);
221+ }
133222 }
134223 }
224+
225+
135226}
136227
137- #[ cfg( llvm_enzyme) ]
228+ #[ cfg( not ( llvm_enzyme) ) ]
138229pub ( crate ) use self :: Fallback_AD :: * ;
139230
140- #[ cfg( llvm_enzyme) ]
231+ #[ cfg( not ( llvm_enzyme) ) ]
141232pub ( crate ) mod Fallback_AD {
142233 #![ allow( unused_variables) ]
143234
144235 pub ( crate ) fn set_inline ( val : bool ) {
145- // unimplemented!()
236+ unimplemented ! ( )
146237 }
147238 pub ( crate ) fn set_print_perf ( print : bool ) {
148- // unimplemented!()
239+ unimplemented ! ( )
149240 }
150241 pub ( crate ) fn set_print_activity ( print : bool ) {
151- // unimplemented!()
242+ unimplemented ! ( )
152243 }
153244 pub ( crate ) fn set_print_type ( print : bool ) {
154- // unimplemented!()
245+ unimplemented ! ( )
155246 }
156247 pub ( crate ) fn set_print_type_fun ( fun_name : & str ) {
157- // unimplemented!()
248+ unimplemented ! ( )
158249 }
159250 pub ( crate ) fn set_print ( print : bool ) {
160- // unimplemented!()
251+ unimplemented ! ( )
161252 }
162253 pub ( crate ) fn set_strict_aliasing ( strict : bool ) {
163- // unimplemented!()
254+ unimplemented ! ( )
164255 }
165256 pub ( crate ) fn set_loose_types ( loose : bool ) {
166- // unimplemented!()
257+ unimplemented ! ( )
167258 }
168259 pub ( crate ) fn set_rust_rules ( val : bool ) {
169- // unimplemented!()
260+ unimplemented ! ( )
170261 }
171262}
0 commit comments