diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 69ecf753e..400f45b56 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -16,71 +16,71 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - test-DI-Core: - name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 120 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false # TODO: toggle - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Internals - - SimpleFiniteDiff - - ZeroBackends - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DI_TEST_TYPE: 'Core' - JULIA_DI_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v6 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --color=yes -e ' - using Pkg; - Pkg.activate("./DifferentiationInterface/test"); - if VERSION < v"1.11"; - Pkg.rm("DifferentiationInterfaceTest"); - Pkg.resolve(); - else; - Pkg.develop(; path="./DifferentiationInterfaceTest"); - end; - Pkg.activate("./DifferentiationInterface"); - test_kwargs = (; allow_reresolve=false, coverage=true); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); - else; - Pkg.test("DifferentiationInterface"; test_kwargs...); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DI - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DI-Core: + # name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 120 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: false # TODO: toggle + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Internals + # - SimpleFiniteDiff + # - ZeroBackends + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DI_TEST_TYPE: 'Core' + # JULIA_DI_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v6 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --color=yes -e ' + # using Pkg; + # Pkg.activate("./DifferentiationInterface/test"); + # if VERSION < v"1.11"; + # Pkg.rm("DifferentiationInterfaceTest"); + # Pkg.resolve(); + # else; + # Pkg.develop(; path="./DifferentiationInterfaceTest"); + # end; + # Pkg.activate("./DifferentiationInterface"); + # test_kwargs = (; allow_reresolve=false, coverage=true); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); + # else; + # Pkg.test("DifferentiationInterface"; test_kwargs...); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DI + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false test-DI-Backend: name: ${{ matrix.version }} - DI Back (${{ matrix.group }}) @@ -94,26 +94,27 @@ jobs: fail-fast: false # TODO: toggle matrix: version: - - '1.10' + # - '1.10' - '1.11' - - '1.12' + # - '1.12' group: - - ChainRules - - DifferentiateWith + # - ChainRules + # - DifferentiateWith # - Diffractor - - Enzyme - - FastDifferentiation - - FiniteDiff - - FiniteDifferences - - ForwardDiff - - GTPSA - - Mooncake - - PolyesterForwardDiff - - ReverseDiff - - SparsityDetector - - Symbolics - - Tracker - - Zygote + # - Enzyme + # - FastDifferentiation + # - FiniteDiff + # - FiniteDifferences + # - ForwardDiff + # - GTPSA + # - Mooncake + # - PolyesterForwardDiff + - Reactant + # - ReverseDiff + # - SparsityDetector + # - Symbolics + # - Tracker + # - Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -157,61 +158,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - test-DIT: - name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false # TODO: toggle - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Formalities - - Zero - - Standard - - Weird - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v6 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - using Pkg; - Pkg.Registry.update(); - Pkg.develop(path="./DifferentiationInterface"); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - else; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DIT - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DIT: + # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 60 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: false # TODO: toggle + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Formalities + # - Zero + # - Standard + # - Weird + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v6 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + # using Pkg; + # Pkg.Registry.update(); + # Pkg.develop(path="./DifferentiationInterface"); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + # else; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DIT + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ded9bd6c3..1cbea565f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -21,6 +21,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" @@ -46,6 +47,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [ "ForwardDiff", "DiffResults", ] +DifferentiationInterfaceReactantExt = ["Reactant", "Enzyme"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -56,7 +58,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.18.0" +ADTypes = "1.19.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" @@ -71,6 +73,7 @@ GTPSA = "1.4.0" LinearAlgebra = "1" Mooncake = "0.4.175" PolyesterForwardDiff = "0.1.2" +Reactant = "0.2.178" ReverseDiff = "1.15.1" SparseArrays = "1" SparseConnectivityTracer = "0.6.14, 1" diff --git a/DifferentiationInterface/docs/Project.toml b/DifferentiationInterface/docs/Project.toml index 24bc8895a..67241f1fc 100644 --- a/DifferentiationInterface/docs/Project.toml +++ b/DifferentiationInterface/docs/Project.toml @@ -9,10 +9,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +DifferentiationInterface = {path = ".."} + [compat] ADTypes = "1" BenchmarkTools = "1" @@ -26,6 +30,3 @@ SparseConnectivityTracer = "1.1.2" SparseMatrixColorings = "0.4.23" Zygote = "0.7.10" julia = "1.10.10" - -[sources] -DifferentiationInterface = { path = ".." } diff --git a/DifferentiationInterface/docs/make.jl b/DifferentiationInterface/docs/make.jl index 334b5b83e..45ae93590 100644 --- a/DifferentiationInterface/docs/make.jl +++ b/DifferentiationInterface/docs/make.jl @@ -10,6 +10,7 @@ using Zygote: Zygote links = InterLinks( "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/", + "Reactant" => "https://enzymead.github.io/Reactant.jl/stable/", "SparseConnectivityTracer" => "https://adrianhill.de/SparseConnectivityTracer.jl/stable/", "SparseMatrixColorings" => "https://gdalle.github.io/SparseMatrixColorings.jl/stable/", "Symbolics" => "https://symbolics.juliasymbolics.org/stable/", diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 108f1c03d..0d03227aa 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -144,3 +144,10 @@ DifferentiationInterface.AutoReverseFromPrimitive ```@docs DifferentiationInterface.Prep ``` + +### Reactant + +```@docs +AutoReactant +DifferentiationInterface.to_reactant +``` diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 0da5201a9..3c6eb66a9 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -19,6 +19,8 @@ We support the following dense backend choices from [ADTypes.jl](https://github. - [`AutoTracker`](@extref ADTypes.AutoTracker) - [`AutoZygote`](@extref ADTypes.AutoZygote) +In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants. + ## Features Given a backend object, you can use: @@ -167,6 +169,10 @@ If a GTPSA [`Descriptor`](https://bmad-sim.github.io/GTPSA.jl/stable/man/b_descr Most operators fall back on `AutoForwardDiff`. +### Reactant + +See the docstring for [`AutoReactant`](@ref). + ### ReverseDiff With `AutoReverseDiff(compile=false)`, preparation preallocates a [config](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractConfig-API). diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl new file mode 100644 index 000000000..7740b53a6 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl @@ -0,0 +1,13 @@ +module DifferentiationInterfaceReactantExt + +using ADTypes: ADTypes, AutoReactant +import DifferentiationInterface as DI +using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray + +DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode) +DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode) + +include("utils.jl") +include("onearg.jl") + +end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl new file mode 100644 index 000000000..d0e4acebd --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl @@ -0,0 +1,78 @@ +struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + xr::XR + gr::GR + compiled_gradient::CG + compiled_gradient!::CG! + compiled_value_and_gradient::CVG + compiled_value_and_gradient!::CVG! +end + +function DI.prepare_gradient_nokwarg( + strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + _sig = DI.signature(f, rebackend, x; strict) + backend = rebackend.mode + xr = x isa ConcreteRArray ? nothing : ConcreteRArray(x) + gr = x isa ConcreteRArray ? nothing : ConcreteRArray(similar(x)) + contextsr = map(_to_reactant, contexts) + compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...) + compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...) + compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...) + compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...) + return ReactantGradientPrep( + _sig, + xr, + gr, + compiled_gradient, + compiled_gradient!, + compiled_value_and_gradient, + compiled_value_and_gradient!, + ) +end + +function DI.gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + contextsr = map(_to_reactant, contexts) + gr = prep.compiled_gradient(f, backend, xr, contextsr...) + return gr +end + +function DI.value_and_gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + contextsr = map(_to_reactant, contexts) + yr, gr = prep.compiled_value_and_gradient(f, backend, xr, contextsr...) + return yr, gr +end + +function DI.gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + gr = isnothing(prep.gr) ? grad : prep.gr + contextsr = map(_to_reactant, contexts) + prep.compiled_gradient!(f, gr, backend, xr, contextsr...) + return isnothing(prep.gr) ? grad : copyto!(grad, gr) +end + +function DI.value_and_gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} + DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + gr = isnothing(prep.gr) ? grad : prep.gr + contextsr = map(_to_reactant, contexts) + yr, gr = prep.compiled_value_and_gradient!(f, gr, backend, xr, contextsr...) + return yr, isnothing(prep.gr) ? grad : copyto!(grad, gr) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl new file mode 100644 index 000000000..d1579f975 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl @@ -0,0 +1,3 @@ +_to_reactant(x) = DI.to_reactant(x) +_to_reactant(c::DI.Constant) = DI.Constant(_to_reactant(DI.unwrap(c))) +_to_reactant(c::DI.Cache) = DI.Cache(_to_reactant(DI.unwrap(c))) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 4a07e7301..a3cddd305 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -30,6 +30,7 @@ using ADTypes: AutoMooncake, AutoMooncakeForward, AutoPolyesterForwardDiff, + AutoReactant, AutoReverseDiff, AutoSymbolics, AutoTracker, @@ -68,6 +69,7 @@ include("misc/sparsity_detector.jl") include("misc/simple_finite_diff.jl") include("misc/zero_backends.jl") include("misc/overloading.jl") +include("misc/reactant.jl") ## Exported @@ -118,6 +120,7 @@ export AutoGTPSA export AutoMooncake export AutoMooncakeForward export AutoPolyesterForwardDiff +export AutoReactant export AutoReverseDiff export AutoSymbolics export AutoTracker @@ -130,6 +133,7 @@ export AutoSparse @public inner, outer @public AutoForwardFromPrimitive, AutoReverseFromPrimitive @public Prep +@public to_reactant include("init.jl") diff --git a/DifferentiationInterface/src/misc/reactant.jl b/DifferentiationInterface/src/misc/reactant.jl new file mode 100644 index 000000000..0d0ba9219 --- /dev/null +++ b/DifferentiationInterface/src/misc/reactant.jl @@ -0,0 +1,84 @@ +""" +!!! tip "DI-specific information" + This part of the docstring is related to the use of `AutoReactant` inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), or DI for short. + Reactant's tutorial on [partial evaluation](https://enzymead.github.io/Reactant.jl/stable/tutorials/partial-evaluation) is useful reading to understand what follows. + +The `AutoReactant` backend inside DI imposes the following restrictions / assumptions: + +- The only supported operator (at the moment) is `DI.gradient` (along with its variants). +- The input `x` must be an `AbstractArray` such that `Reactant.ConcreteRArray(x)` is well-defined. +- By default, contexts such as `DI.Constant` and `DI.Cache` will be partially evaluated inside the compiled differentiation operator at preparation time. This means that the context value provided at preparation will be reused at every subsequent execution, while the context value provided at execution will be ignored. In particular, `DI.Cache` contexts will usually error and `DI.Constant` contexts will be frozen to one value. + +To disable partial evaluation and enforce tracing of contexts instead, first wrap them into types that _you own_. +Then, overload [`DifferentiationInterface.to_reactant`](@ref) on these types to perform tracing in the way you see fit, for instance with `Reactant.to_rarray`. +Every value you choose not to trace will still be partially evaluated at preparation time. + +# Example + +```jldoctest +using DifferentiationInterface +import DifferentiationInterface as DI +import Reactant + +struct MyArgument{T1 <: Number, T2 <: AbstractArray} + u::T1 + v::T2 +end + +f(x, a::MyArgument) = a.u * sum(a.v .* x .^ 2) + +DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false) + +# preparation time +x0 = zeros(2) +a0 = MyArgument(1.0, [2.0, 3.0]) + +# execution time +x = [4.0, 5.0] +a = MyArgument(6.0, [7.0, 8.0]) + +backend = AutoReactant() +prep = prepare_gradient(f, backend, x0, Constant(a0)); + +g = gradient(f, prep, backend, x, Constant(a)) +g ≈ a0.u * 2 * (a.v .* x) # a0.u is partially evaluated, a0.v is traced + +# output + +true +``` +""" +AutoReactant + +""" + to_reactant(a) + +Convert an argument `a` to an object `ar` containing the same values, where all the fields and subfields that can contain active (differentiated) data have been translated to [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) types such as [`ConcreteRArray`](@extref Reactant.ConcreteRArray) or [`ConcreteRNumber`](@extref Reactant.ConcreteRNumber). + +!!! danger + DifferentiationInterface.jl implements this function as the identity, on purpose. + It should not be overloaded on base types, but only on types that you own, to modify the default behavior of `AutoReactant`. + +# Example + +```jldoctest +import DifferentiationInterface as DI +import Reactant + +struct MyArgument{T1 <: Number, T2 <: AbstractArray} + u::T1 + v::T2 +end + +DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false) + +a = MyArgument(1.0, [2.0, 3.0]) +ar = DI.to_reactant(a) +ar isa MyArgument{Float64, <:Reactant.ConcreteRArray} + +# output + +true +``` +""" +to_reactant(x) = x diff --git a/DifferentiationInterface/test/Back/Reactant/Project.toml b/DifferentiationInterface/test/Back/Reactant/Project.toml new file mode 100644 index 000000000..9ec59dc7b --- /dev/null +++ b/DifferentiationInterface/test/Back/Reactant/Project.toml @@ -0,0 +1,10 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl new file mode 100644 index 000000000..80dd7f6d1 --- /dev/null +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -0,0 +1,21 @@ +include("../../testutils.jl") + +using DifferentiationInterface +import DifferentiationInterface as DI +using DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT +import Enzyme, Reactant +using Test + +backend = AutoReactant() + +@test check_available(backend) +@test check_inplace(backend) + +test_differentiation( + backend, DifferentiationInterfaceTest.default_scenarios(; + include_constantified = false, include_cachified = false + ); + excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback), + logging = false +) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index a417d9eac..cb718f78c 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -10,6 +10,7 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -40,6 +41,7 @@ DataFrames = "1.6.1" DifferentiationInterface = "0.7.7" DocStringExtensions = "0.8,0.9" ForwardDiff = "0.10.36,1" +GPUArraysCore = "0.2.0" JET = "0.9,0.10,0.11" JLArrays = "0.1,0.2,0.3" LinearAlgebra = "1" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 51b216915..84d9d2b38 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -92,6 +92,7 @@ using DifferentiationInterface: using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES +using GPUArraysCore: @allowscalar using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent using PrecompileTools: @compile_workload diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 5761eb346..bef174384 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -224,8 +224,8 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries y = sc.f(x) if y isa Number - y_cache[1] = y - return y_cache[1] + @allowscalar y_cache[1] = y + return @allowscalar y_cache[1] else copyto!(y_cache, y) return copy(y_cache)