-
Notifications
You must be signed in to change notification settings - Fork 27
feat: add gradient with AutoReactant #918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #918 +/- ##
==========================================
- Coverage 98.23% 2.40% -95.83%
==========================================
Files 133 101 -32
Lines 7968 5536 -2432
==========================================
- Hits 7827 133 -7694
- Misses 141 5403 +5262
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@wsmoses does this look better to you now? |
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
Outdated
Show resolved
Hide resolved
| @test check_inplace(backend) | ||
|
|
||
| test_differentiation( | ||
| backend, DifferentiationInterfaceTest.default_scenarios(; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test that the prep contains no data except the compiled fn if compiled for a reactant array
|
Modulo some comments being addressed above this looks reasonable to me. However note that there may be a potential mismatch in expectations behind what prepare gradient defines and what reactant compile defines. See https://enzymead.github.io/Reactant.jl/dev/tutorials/partial-evaluation Currently any data inside a constant or cache will be baked into the compiled function and will not be re read in later evaluation. Enzyme.jl in particular does not have any such constraint (as it will always re run with live data as prep is nothing). Something like reversediff compiled probably does bake in the assumption from compilation. So this is a question of what is the semantics of prep. If the non differentiated data is the same between prep and evaluation there is no difference between the two |
The semantics of
I thought converting the |
|
No your to_reac function does not achieve that. If you have a context of Tuple{Int, Int} this will not be converted unless you do However, concurrently, most of the time you actually want to partially evaluate integers in (e.g. for sizes/bounds/etc). I think the more reasonable setup here is to not to_rarray the context, and instead add a similar warning to the one from reversediff: |
|
Being forced to keep the same context values makes preparation pretty much useless. I think I'd rather have us trace everything in the context, even if it means a slowdown in some cases. Will it lead to actual errors? |
|
Or alternately we could restrict the kind of contexts we allow here |
|
Yes unnecessarily tracing objects can lead to errors that would fail to compile otherwise |
|
and it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change |
Can you give an example so that I wrap my mind around this?
That would be a Reactant-specific workaround, which doesn't fit other DI-supported backends. The whole point of DI is to enable easy backend switch, so I'd love to find a solution that doesn't expect users to wrap some of the arguments in Reactant-specific types when they want to switch to |
|
Besides, the problem is not specific to contexts: Maybe DI could expose a function like |
|
|
and I don't think this is terribly reactant-specific. The same core issue here equally applies to reversediff compiled [where the context will equally be baked it]. Just because reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here. |
bumping this @gdalle are you okay not to trace the contexts? |
|
Not really. Many use cases of DI that I can think of require changing contexts, and Reactant has to be relevant for these cases too. Re-compiling the derivatives for each context changes is very impractical. On the other hand, asking users to pass I don't have a lot of bandwith these days, but I think the right solution might be to expose something like |
|
In order to be differentiated the data must be a reactant array. If we assume that DI only officially supports array inputs this is fine |
"array" as in " |
|
By the way, this Reactant issue prevented me from testing |
|
@wsmoses we're still missing tests but I added detailed documentation on the behavior that you wanted us to enforce. Can you review it to make sure it is sensible? |
| - [`AutoTracker`](@extref ADTypes.AutoTracker) | ||
| - [`AutoZygote`](@extref ADTypes.AutoZygote) | ||
|
|
||
| In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove the experimental word, but just say the incomplete implementation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example it's better supported here than Diffractor listed above as fully supported
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Diffractor is listed in the README as broken, but sure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another argument for the "experimental" support is protecting us if we decide that we made an API mistake, because breaking an "experimental" feature is arguably ok within SemVer
| ) where {F, C} | ||
| _sig = DI.signature(f, rebackend, x; strict) | ||
| backend = rebackend.mode | ||
| xr = x isa ConcreteRArray ? nothing : ConcreteRArray(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If desired you could also do a check for abstract number and select ConcreteRNumber as a result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DI doesn't support gradient with number inputs for all backends (ForwardDiff will fail for instance), so the official API guarantee is for x::AbstractArray. Of course, reverse-mode backends are more generic than that, but I don't think it's bad to keep that restriction for now
| DI.check_prep(f, prep, rebackend, x) | ||
| backend = rebackend.mode | ||
| xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) | ||
| contextsr = map(_to_reactant, contexts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're going to use this design you need three functions to_reactant, copyto and copyfrom
Honestly I wouldn't put this functionality into DI atm, users can already do this before the DI call and it's easier to add functions than remove them later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly I wouldn't put this functionality into DI atm, users can already do this before the DI call and it's easier to add functions than remove them later
The point of putting it in DI is that the same code will work with different backends. If you force people to convert their arguments to Reactant types before the DI call, these types may interact badly with some other autodiff backends, leading to code that is Reactant-only.
By the same token, I don't ask users to create their own ForwardDiff.Dual storage, or their own symbolic numbers, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're going to use this design you need three functions to_reactant, copyto and copyfrom
Only for performance reasons, right? Since you said that Reactant's main use would be with non-traced Constants, the conversion will be a no-op anyway in such cases.
I'd be in favor of trying it this way and then thinking about copy utilities as a second step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, it's difficult to decide whether a type is mutable (hence amenable to copy!) or not. For instance, EnzymeAD/Enzyme.jl#2587 doesn't have a satisfying solution yet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with that for the differentiated argument (which is what is done elsewhere too), but this is already extremely backend specific and seems likely to confuse people wrt the reactant.to_rarray and such.
Given the discussion above that the absence of these as a default is not going to be likely to cause problems, and it's easier to add APIs later without breakage than remove them, I think it's better to start simple and add this later if it becomes relevant
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I strongly disagree and I would ask you to trust my gut on this one, based on DI usage I'm seeing in the wild. While scenarios that you care about mostly seem to include non-traceable constants, there are plenty of other cases where users will need the flexibility of picking which contexts they trace and how. If we don't give it to them, they may never even try the switch to Reactant, which I think would be a poor outcome.
Putting this behind a big "experimental: expect breaking stuff" warning, at least for the upcoming release, keeps us free to modify the API if we get the design wrong the first time. In addition, I've been playing with the idea of a similar "translation" API for other backends, so it may not end up being backend-specific after all. I could rename this into translate(backend, arg) and let users customize it also for dual or symbolic numbers depending on their needs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hear you, but I strongly disagree as well.
As to your specific point:
"they may never even try the switch to Reactant, which I think would be a poor outcome"
That is totally okay. Usage within DI is meant as a quick validation and ease of use. In most cases reactant will be called outside of any derivative calls, not the inside of like here.
More importantly, if we start small and simple here, we can expand based off of actual usage instead of pre-supposing usage. It also keeps the API a lot cleaner. However you can't as easily un-ring the bell.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't expand based on actual usage if we only allow the narrowest possible usage you can think of. You know existing Reactant users, but I know existing DI users, which are very varied, and I also want to open the doors to applications and users that neither of us knows about yet.
And let me push back on DI mostly being used as quick prototyping, a temporary stepping stone. No doubt that some people using DI could get even more performance by switching to native Enzyme. But many of them won't, precisely because it's a hassle to refactor one more time. The same will likely happen with Reactant: sure, compiling the outer loop would be best, but I'm willing to bet plenty of people will be happy with just accelerating their inner gradient, and it will already change their life for the better.
A clean API is not useful if you cannot do anything with it. Only allowing preparation-time constants and no cache completely strips DI of its substance, and AutoReactant of its usefulness.
I'm the one who will document, test and maintain the API, so I'm willing to take responsibility for how clean it is. I haven't had too many complaints about DI being unclean so far, and even if I do get them, changing an experimental API is by definition not breaking. That's why I intend to keep the warning, so we can let people experiment without being afraid to backtrack if we got it wrong.
Updated experiments with Reactant-accelerated derivatives.
@wsmoses is this still the right paradigm in your opinion? I may not implement every operator right away but I thought starting with a gradient made sense
Related:
Warning
Re-toggle tests once this is mergeable