Skip to content

Conversation

@gdalle
Copy link
Member

@gdalle gdalle commented Nov 15, 2025

Updated experiments with Reactant-accelerated derivatives.

@wsmoses is this still the right paradigm in your opinion? I may not implement every operator right away but I thought starting with a gradient made sense

Related:

Warning

Re-toggle tests once this is mergeable

@codecov
Copy link

codecov bot commented Nov 16, 2025

Codecov Report

❌ Patch coverage is 91.66667% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 2.40%. Comparing base (0dd1abf) to head (f08d470).

Files with missing lines Patch % Lines
...e/ext/DifferentiationInterfaceReactantExt/utils.jl 0.00% 3 Missing ⚠️
DifferentiationInterface/src/misc/reactant.jl 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (0dd1abf) and HEAD (f08d470). Click for more details.

HEAD has 60 uploads less than BASE
Flag BASE (0dd1abf) HEAD (f08d470)
DI 50 1
DIT 11 0
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #918       +/-   ##
==========================================
- Coverage   98.23%   2.40%   -95.83%     
==========================================
  Files         133     101       -32     
  Lines        7968    5536     -2432     
==========================================
- Hits         7827     133     -7694     
- Misses        141    5403     +5262     
Flag Coverage Δ
DI 2.40% <91.66%> (-96.61%) ⬇️
DIT ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

@wsmoses does this look better to you now?
I'm not sure what we should do in terms of storage versus allocations. We can store xr (and even contextsr) during preparation and then copy to them at execution time instead of generating a new RArray, but that would require a copying method to be defined (which doesn't apply to all non-array objects).

@test check_inplace(backend)

test_differentiation(
backend, DifferentiationInterfaceTest.default_scenarios(;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that the prep contains no data except the compiled fn if compiled for a reactant array

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

Modulo some comments being addressed above this looks reasonable to me.

However note that there may be a potential mismatch in expectations behind what prepare gradient defines and what reactant compile defines.

See https://enzymead.github.io/Reactant.jl/dev/tutorials/partial-evaluation

Currently any data inside a constant or cache will be baked into the compiled function and will not be re read in later evaluation.

Enzyme.jl in particular does not have any such constraint (as it will always re run with live data as prep is nothing).

Something like reversediff compiled probably does bake in the assumption from compilation.

So this is a question of what is the semantics of prep.

If the non differentiated data is the same between prep and evaluation there is no difference between the two

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

So this is a question of what is the semantics of prep. If the non differentiated data is the same between prep and evaluation there is no difference between the two

The semantics of prep: differentiated and non-differentiated data are free to change between preparation and execution, as long as they keep the same types and sizes. See here for details.

Currently any data inside a constant or cache will be baked into the compiled function and will not be re read in later evaluation.

I thought converting the contexts into reactant arrays inside contextr would allow them to be traced? If that's true, then they won't be partially evaluated into the compiled function, which means the semantics of prep are respected?

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

No your to_reac function does not achieve that. If you have a context of Tuple{Int, Int} this will not be converted unless you do to_rarray(context; track_numbers=Number).

However, concurrently, most of the time you actually want to partially evaluate integers in (e.g. for sizes/bounds/etc).

I think the more reasonable setup here is to not to_rarray the context, and instead add a similar warning to the one from reversediff:

These rules hold for the majority of backends, but there are some exceptions. The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function.

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

Being forced to keep the same context values makes preparation pretty much useless. I think I'd rather have us trace everything in the context, even if it means a slowdown in some cases. Will it lead to actual errors?

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

Or alternately we could restrict the kind of contexts we allow here

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

Yes unnecessarily tracing objects can lead to errors that would fail to compile otherwise

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

and it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change

@gdalle
Copy link
Member Author

gdalle commented Nov 17, 2025

Yes unnecessarily tracing objects can lead to errors that would fail to compile otherwise

Can you give an example so that I wrap my mind around this?

and it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change

That would be a Reactant-specific workaround, which doesn't fit other DI-supported backends. The whole point of DI is to enable easy backend switch, so I'd love to find a solution that doesn't expect users to wrap some of the arguments in Reactant-specific types when they want to switch to AutoReactant.

@gdalle
Copy link
Member Author

gdalle commented Nov 17, 2025

Besides, the problem is not specific to contexts: x itself can contain integers we don't necessarily want to track. And a preparation that can only be reused if nothing at all changes will only ever be used once anyway, so it is pointless.

Maybe DI could expose a function like trace(a, backend) or translate(a, backend) which takes care of populating values in the correct way for differentiation / Reactant compilation? I already use such a function internally anyway, especially in ForwardDiff and other operator overloading-based backends.

@wsmoses
Copy link

wsmoses commented Nov 17, 2025

julia> using Reactant; x = Reactant.to_rarray(ones(10)); s = Reactant.ConcreteRNumber(2); e = Reactant.ConcreteRNumber(5);

julia> f(x, s, e) = x[s:e]
f (generic function with 1 method)

julia> @jit f(x, s, e)
ERROR: TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
  [1] getindex_linear
    @ ~/git/Reactant.jl/src/Indexing.jl:340 [inlined]
  [2] (::Nothing)(none::typeof(Reactant.TracedIndexing.getindex_linear), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ./<missing>:0
  [3] getindex_linear
    @ ~/git/Reactant.jl/src/Indexing.jl:339 [inlined]
  [4] call_with_reactant(::Reactant.MustThrowError, ::typeof(Reactant.TracedIndexing.getindex_linear), ::Reactant.TracedRArray{…}, ::Reactant.TracedUnitRange{…})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
  [5] getindex
    @ ~/git/Reactant.jl/src/Indexing.jl:75 [inlined]
  [6] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ./<missing>:0
  [7] getindex
    @ ~/git/Reactant.jl/src/Indexing.jl:75 [inlined]
  [8] call_with_reactant(::Reactant.MustThrowError, ::typeof(getindex), ::Reactant.TracedRArray{Float64, 1}, ::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
  [9] f
    @ ./REPL[5]:1 [inlined]
 [10] (::Nothing)(none::typeof(f), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedRNumber{Int64}, none::Reactant.TracedRNumber{Int64})
    @ Reactant ./<missing>:0
 [11] TracedUnitRange
    @ ~/git/Reactant.jl/src/Types.jl:108 [inlined]
 [12] TracedUnitRange
    @ ~/git/Reactant.jl/src/TracedRange.jl:124 [inlined]
 [13] Colon
    @ ~/git/Reactant.jl/src/TracedRange.jl:181 [inlined]
 [14] f
    @ ./REPL[5]:1 [inlined]
 [15] call_with_reactant(::typeof(f), ::Reactant.TracedRArray{Float64, 1}, ::Reactant.TracedRNumber{Int64}, ::Reactant.TracedRNumber{Int64})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
 [16] make_mlir_fn(f::typeof(f), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/git/Reactant.jl/src/TracedUtils.jl:345
 [17] make_mlir_fn
    @ ~/git/Reactant.jl/src/TracedUtils.jl:275 [inlined]
 [18] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(f), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:1608
 [19] compile_mlir!
    @ ~/git/Reactant.jl/src/Compiler.jl:1570 [inlined]
 [20] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:3516
 [21] compile_xla
    @ ~/git/Reactant.jl/src/Compiler.jl:3488 [inlined]
 [22] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:3592
 [23] top-level scope
    @ ~/git/Reactant.jl/src/Compiler.jl:2661
Some type information was truncated. Use `show(err)` to see complete types.

julia> @jit f(x, 2, 5)
4-element ConcretePJRTArray{Float64,1}:
 1.0
 1.0
 1.0
 1.0

@wsmoses
Copy link

wsmoses commented Nov 17, 2025

and I don't think this is terribly reactant-specific. The same core issue here equally applies to reversediff compiled [where the context will equally be baked it]. Just because reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here.

@wsmoses
Copy link

wsmoses commented Nov 18, 2025

and I don't think this is terribly reactant-specific. The same core issue here equally applies to reversediff compiled [where the context will equally be baked it]. Just because reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here.

bumping this @gdalle are you okay not to trace the contexts?

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

Not really. Many use cases of DI that I can think of require changing contexts, and Reactant has to be relevant for these cases too. Re-compiling the derivatives for each context changes is very impractical. On the other hand, asking users to pass RArrays instead of their normal arguments might make other backends fail, so the code on which DI runs is no longer fully generic, and I want to avoid that too. It would be like asking ForwardDiff users to pass arrays of Dual numbers.
Besides, that problem is not specific to contexts: what is stopping x itself (the active argument) from containing scalars that users may or may not want to trace?

I don't have a lot of bandwith these days, but I think the right solution might be to expose something like DI.to_reactant, telling users that the function will be called on every argument before Reactant compilation. That way, if they want to enforce specific tracing behavior, they can wrap their argument in a custom type and overload to_reactant, but it doesn't force them to

@wsmoses
Copy link

wsmoses commented Nov 18, 2025

In order to be differentiated the data must be a reactant array. If we assume that DI only officially supports array inputs this is fine

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

In order to be differentiated the data must be a reactant array. If we assume that DI only officially supports array inputs this is fine

"array" as in "RArray only" or as in "any nested struct of RArrays?

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

By the way, this Reactant issue prevented me from testing DI.Cache here, if you happen to have a quick fix lying around. I can also modify the DI EnzymeExt source code if this is expected behavior

@gdalle
Copy link
Member Author

gdalle commented Nov 28, 2025

@wsmoses we're still missing tests but I added detailed documentation on the behavior that you wanted us to enforce. Can you review it to make sure it is sensible?

- [`AutoTracker`](@extref ADTypes.AutoTracker)
- [`AutoZygote`](@extref ADTypes.AutoZygote)

In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove the experimental word, but just say the incomplete implementation?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example it's better supported here than Diffractor listed above as fully supported

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diffractor is listed in the README as broken, but sure

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another argument for the "experimental" support is protecting us if we decide that we made an API mistake, because breaking an "experimental" feature is arguably ok within SemVer

) where {F, C}
_sig = DI.signature(f, rebackend, x; strict)
backend = rebackend.mode
xr = x isa ConcreteRArray ? nothing : ConcreteRArray(x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If desired you could also do a check for abstract number and select ConcreteRNumber as a result

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DI doesn't support gradient with number inputs for all backends (ForwardDiff will fail for instance), so the official API guarantee is for x::AbstractArray. Of course, reverse-mode backends are more generic than that, but I don't think it's bad to keep that restriction for now

DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x)
contextsr = map(_to_reactant, contexts)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to use this design you need three functions to_reactant, copyto and copyfrom

Honestly I wouldn't put this functionality into DI atm, users can already do this before the DI call and it's easier to add functions than remove them later

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I wouldn't put this functionality into DI atm, users can already do this before the DI call and it's easier to add functions than remove them later

The point of putting it in DI is that the same code will work with different backends. If you force people to convert their arguments to Reactant types before the DI call, these types may interact badly with some other autodiff backends, leading to code that is Reactant-only.
By the same token, I don't ask users to create their own ForwardDiff.Dual storage, or their own symbolic numbers, etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to use this design you need three functions to_reactant, copyto and copyfrom

Only for performance reasons, right? Since you said that Reactant's main use would be with non-traced Constants, the conversion will be a no-op anyway in such cases.
I'd be in favor of trying it this way and then thinking about copy utilities as a second step.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's difficult to decide whether a type is mutable (hence amenable to copy!) or not. For instance, EnzymeAD/Enzyme.jl#2587 doesn't have a satisfying solution yet

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that for the differentiated argument (which is what is done elsewhere too), but this is already extremely backend specific and seems likely to confuse people wrt the reactant.to_rarray and such.

Given the discussion above that the absence of these as a default is not going to be likely to cause problems, and it's easier to add APIs later without breakage than remove them, I think it's better to start simple and add this later if it becomes relevant

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly disagree and I would ask you to trust my gut on this one, based on DI usage I'm seeing in the wild. While scenarios that you care about mostly seem to include non-traceable constants, there are plenty of other cases where users will need the flexibility of picking which contexts they trace and how. If we don't give it to them, they may never even try the switch to Reactant, which I think would be a poor outcome.

Putting this behind a big "experimental: expect breaking stuff" warning, at least for the upcoming release, keeps us free to modify the API if we get the design wrong the first time. In addition, I've been playing with the idea of a similar "translation" API for other backends, so it may not end up being backend-specific after all. I could rename this into translate(backend, arg) and let users customize it also for dual or symbolic numbers depending on their needs.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hear you, but I strongly disagree as well.

As to your specific point:
"they may never even try the switch to Reactant, which I think would be a poor outcome"

That is totally okay. Usage within DI is meant as a quick validation and ease of use. In most cases reactant will be called outside of any derivative calls, not the inside of like here.

More importantly, if we start small and simple here, we can expand based off of actual usage instead of pre-supposing usage. It also keeps the API a lot cleaner. However you can't as easily un-ring the bell.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't expand based on actual usage if we only allow the narrowest possible usage you can think of. You know existing Reactant users, but I know existing DI users, which are very varied, and I also want to open the doors to applications and users that neither of us knows about yet.

And let me push back on DI mostly being used as quick prototyping, a temporary stepping stone. No doubt that some people using DI could get even more performance by switching to native Enzyme. But many of them won't, precisely because it's a hassle to refactor one more time. The same will likely happen with Reactant: sure, compiling the outer loop would be best, but I'm willing to bet plenty of people will be happy with just accelerating their inner gradient, and it will already change their life for the better.

A clean API is not useful if you cannot do anything with it. Only allowing preparation-time constants and no cache completely strips DI of its substance, and AutoReactant of its usefulness.
I'm the one who will document, test and maintain the API, so I'm willing to take responsibility for how clean it is. I haven't had too many complaints about DI being unclean so far, and even if I do get them, changing an experimental API is by definition not breaking. That's why I intend to keep the warning, so we can let people experiment without being afraid to backtrack if we got it wrong.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants