@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
77use std:: { fs, io, mem, str, thread} ;
88
99use rustc_ast:: attr;
10+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
1011use rustc_data_structures:: fx:: { FxHashMap , FxIndexMap } ;
1112use rustc_data_structures:: jobserver:: { self , Acquired } ;
1213use rustc_data_structures:: memmap:: Mmap ;
@@ -40,7 +41,7 @@ use tracing::debug;
4041use super :: link:: { self , ensure_removed} ;
4142use super :: lto:: { self , SerializedModule } ;
4243use super :: symbol_export:: symbol_name_for_instance_in_crate;
43- use crate :: errors:: ErrorCreatingRemarkDir ;
44+ use crate :: errors:: { AutodiffWithoutLto , ErrorCreatingRemarkDir } ;
4445use crate :: traits:: * ;
4546use crate :: {
4647 CachedModuleCodegen , CodegenResults , CompiledModule , CrateInfo , ModuleCodegen , ModuleKind ,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118119 pub merge_functions : bool ,
119120 pub emit_lifetime_markers : bool ,
120121 pub llvm_plugins : Vec < String > ,
122+ pub autodiff : Vec < config:: AutoDiff > ,
121123}
122124
123125impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
266268
267269 emit_lifetime_markers : sess. emit_lifetime_markers ( ) ,
268270 llvm_plugins : if_regular ! ( sess. opts. unstable_opts. llvm_plugins. clone( ) , vec![ ] ) ,
271+ autodiff : if_regular ! ( sess. opts. unstable_opts. autodiff. clone( ) , vec![ ] ) ,
269272 }
270273 }
271274
@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389392
390393fn generate_lto_work < B : ExtraBackendMethods > (
391394 cgcx : & CodegenContext < B > ,
395+ autodiff : Vec < AutoDiffItem > ,
392396 needs_fat_lto : Vec < FatLtoInput < B > > ,
393397 needs_thin_lto : Vec < ( String , B :: ThinBuffer ) > ,
394398 import_only_modules : Vec < ( SerializedModule < B :: ModuleBuffer > , WorkProduct ) > ,
@@ -399,9 +403,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
399403 assert ! ( needs_thin_lto. is_empty( ) ) ;
400404 let module =
401405 B :: run_fat_lto ( cgcx, needs_fat_lto, import_only_modules) . unwrap_or_else ( |e| e. raise ( ) ) ;
406+ if cgcx. lto == Lto :: Fat {
407+ let _config = cgcx. config ( ModuleKind :: Regular ) ;
408+ todo ! ( "fat LTO with autodiff is not yet implemented" ) ;
409+ //module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
410+ }
402411 // We are adding a single work item, so the cost doesn't matter.
403412 vec ! [ ( WorkItem :: LTO ( module) , 0 ) ]
404413 } else {
414+ if !autodiff. is_empty ( ) {
415+ let dcx = cgcx. create_dcx ( ) ;
416+ dcx. handle ( ) . emit_fatal ( AutodiffWithoutLto { } ) ;
417+ }
405418 assert ! ( needs_fat_lto. is_empty( ) ) ;
406419 let ( lto_modules, copy_jobs) = B :: run_thin_lto ( cgcx, needs_thin_lto, import_only_modules)
407420 . unwrap_or_else ( |e| e. raise ( ) ) ;
@@ -1021,6 +1034,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10211034 /// Sent from a backend worker thread.
10221035 WorkItem { result : Result < WorkItemResult < B > , Option < WorkerFatalError > > , worker_id : usize } ,
10231036
1037+ /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1038+ AddAutoDiffItems ( Vec < AutoDiffItem > ) ,
1039+
10241040 /// The frontend has finished generating something (backend IR or a
10251041 /// post-LTO artifact) for a codegen unit, and it should be passed to the
10261042 /// backend. Sent from the main thread.
@@ -1348,6 +1364,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13481364
13491365 // This is where we collect codegen units that have gone all the way
13501366 // through codegen and LLVM.
1367+ let mut autodiff_items = Vec :: new ( ) ;
13511368 let mut compiled_modules = vec ! [ ] ;
13521369 let mut compiled_allocator_module = None ;
13531370 let mut needs_link = Vec :: new ( ) ;
@@ -1459,9 +1476,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14591476 let needs_thin_lto = mem:: take ( & mut needs_thin_lto) ;
14601477 let import_only_modules = mem:: take ( & mut lto_import_only_modules) ;
14611478
1462- for ( work, cost) in
1463- generate_lto_work ( & cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1464- {
1479+ for ( work, cost) in generate_lto_work (
1480+ & cgcx,
1481+ autodiff_items. clone ( ) ,
1482+ needs_fat_lto,
1483+ needs_thin_lto,
1484+ import_only_modules,
1485+ ) {
14651486 let insertion_index = work_items
14661487 . binary_search_by_key ( & cost, |& ( _, cost) | cost)
14671488 . unwrap_or_else ( |e| e) ;
@@ -1596,6 +1617,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
15961617 main_thread_state = MainThreadState :: Idle ;
15971618 }
15981619
1620+ Message :: AddAutoDiffItems ( mut items) => {
1621+ autodiff_items. append ( & mut items) ;
1622+ }
1623+
15991624 Message :: CodegenComplete => {
16001625 if codegen_state != Aborted {
16011626 codegen_state = Completed ;
@@ -2070,6 +2095,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20702095 drop ( self . coordinator . sender . send ( Box :: new ( Message :: CodegenComplete :: < B > ) ) ) ;
20712096 }
20722097
2098+ pub ( crate ) fn submit_autodiff_items ( & self , items : Vec < AutoDiffItem > ) {
2099+ drop ( self . coordinator . sender . send ( Box :: new ( Message :: < B > :: AddAutoDiffItems ( items) ) ) ) ;
2100+ }
2101+
20732102 pub ( crate ) fn check_for_errors ( & self , sess : & Session ) {
20742103 self . shared_emitter_main . check ( sess, false ) ;
20752104 }
0 commit comments