From 2f135972a9fa926c4addb9188ee9eb89a051e529 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Nov 2025 00:55:29 +0100 Subject: [PATCH 1/5] feat: add gradient with AutoReactant --- .github/workflows/Test.yml | 159 +++++++++--------- DifferentiationInterface/Project.toml | 5 +- .../DifferentiationInterfaceReactantExt.jl | 12 ++ .../onearg.jl | 81 +++++++++ .../src/DifferentiationInterface.jl | 2 + .../test/Back/Reactant/test.jl | 14 ++ 6 files changed, 194 insertions(+), 79 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl create mode 100644 DifferentiationInterface/test/Back/Reactant/test.jl diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 89223f597..e490a0a7c 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -28,29 +28,30 @@ jobs: fail-fast: true # TODO: toggle matrix: version: - - '1.10' + # - '1.10' - '1.11' - '1.12' group: - - Core/Internals - - Back/DifferentiateWith - - Core/SimpleFiniteDiff - - Back/SparsityDetector - - Core/ZeroBackends - - Back/ChainRules - # - Back/Diffractor - - Back/Enzyme - - Back/FastDifferentiation - - Back/FiniteDiff - - Back/FiniteDifferences - - Back/ForwardDiff - - Back/GTPSA - - Back/Mooncake - - Back/PolyesterForwardDiff - - Back/ReverseDiff - - Back/Symbolics - - Back/Tracker - - Back/Zygote + # - Core/Internals + # - Back/DifferentiateWith + # - Core/SimpleFiniteDiff + # - Back/SparsityDetector + # - Core/ZeroBackends + # - Back/ChainRules + # # - Back/Diffractor + # - Back/Enzyme + # - Back/FastDifferentiation + # - Back/FiniteDiff + # - Back/FiniteDifferences + # - Back/ForwardDiff + # - Back/GTPSA + # - Back/Mooncake + # - Back/PolyesterForwardDiff + - Back/Reactant + # - Back/ReverseDiff + # - Back/Symbolics + # - Back/Tracker + # - Back/Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -64,6 +65,8 @@ jobs: group: Back/ChainRules - version: '1.12' group: Back/Enzyme + - version: '1.12' + group: Back/Reactant - version: '1.12' group: Back/DifferentiateWith env: @@ -104,61 +107,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - test-DIT: - name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: true - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Formalities - - Zero - - Standard - - Weird - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v5 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - using Pkg; - Pkg.Registry.update(); - Pkg.develop(path="./DifferentiationInterface"); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - else; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DIT - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DIT: + # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 60 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: true + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Formalities + # - Zero + # - Standard + # - Weird + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v5 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + # using Pkg; + # Pkg.Registry.update(); + # Pkg.develop(path="./DifferentiationInterface"); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + # else; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DIT + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ded9bd6c3..1cbea565f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -21,6 +21,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" @@ -46,6 +47,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [ "ForwardDiff", "DiffResults", ] +DifferentiationInterfaceReactantExt = ["Reactant", "Enzyme"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -56,7 +58,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.18.0" +ADTypes = "1.19.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" @@ -71,6 +73,7 @@ GTPSA = "1.4.0" LinearAlgebra = "1" Mooncake = "0.4.175" PolyesterForwardDiff = "0.1.2" +Reactant = "0.2.178" ReverseDiff = "1.15.1" SparseArrays = "1" SparseConnectivityTracer = "0.6.14, 1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl new file mode 100644 index 000000000..9aefb9381 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl @@ -0,0 +1,12 @@ +module DifferentiationInterfaceReactantExt + +using ADTypes: ADTypes, AutoReactant +import DifferentiationInterface as DI +using Reactant: @compile, to_rarray + +DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode) +DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode) + +include("onearg.jl") + +end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl new file mode 100644 index 000000000..c7ebce528 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl @@ -0,0 +1,81 @@ +struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + xr::XR + gr::GR + compiled_gradient::CG + compiled_gradient!::CG! + compiled_value_and_gradient::CVG + compiled_value_and_gradient!::CVG! +end + +function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant, x) where {F} + _sig = DI.signature(f, rebackend, x; strict) + backend = rebackend.mode + xr = to_rarray(x) + gr = to_rarray(similar(x)) + _gradient(_xr) = DI.gradient(f, backend, _xr) + _gradient!(_gr, _xr) = copy!(_gr, DI.gradient(f, backend, _xr)) + _value_and_gradient(_xr) = DI.value_and_gradient(f, backend, _xr) + function _value_and_gradient!(_gr, _xr) + y, __gr = DI.value_and_gradient(f, backend, _xr) + copy!(_gr, __gr) + return y, _gr + end + compiled_gradient = @compile _gradient(xr) + compiled_gradient! = @compile _gradient!(gr, xr) + compiled_value_and_gradient = @compile _value_and_gradient(xr) + compiled_value_and_gradient! = @compile _value_and_gradient!(gr, xr) + return ReactantGradientPrep( + _sig, + xr, + gr, + compiled_gradient, + compiled_gradient!, + compiled_value_and_gradient, + compiled_value_and_gradient!, + ) +end + +function DI.gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x + ) where {F} + DI.check_prep(f, prep, rebackend, x) + (; xr, compiled_gradient) = prep + copy!(xr, x) + gr = compiled_gradient(xr) + g = convert(typeof(x), gr) + return g +end + +function DI.value_and_gradient( + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x + ) where {F} + DI.check_prep(f, prep, rebackend, x) + (; xr, compiled_value_and_gradient) = prep + copy!(xr, x) + yr, gr = compiled_value_and_gradient(xr) + y = convert(eltype(x), yr) + g = convert(typeof(x), gr) + return y, g +end + +function DI.gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x + ) where {F} + DI.check_prep(f, prep, rebackend, x) + (; xr, gr, compiled_gradient!) = prep + copy!(xr, x) + compiled_gradient!(gr, xr) + return copy!(grad, gr) +end + +function DI.value_and_gradient!( + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x + ) where {F} + DI.check_prep(f, prep, rebackend, x) + (; xr, gr, compiled_value_and_gradient!) = prep + copy!(xr, x) + yr, gr = compiled_value_and_gradient!(gr, xr) + y = convert(eltype(x), yr) + return y, copy!(grad, gr) +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 4a07e7301..393b9051a 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -30,6 +30,7 @@ using ADTypes: AutoMooncake, AutoMooncakeForward, AutoPolyesterForwardDiff, + AutoReactant, AutoReverseDiff, AutoSymbolics, AutoTracker, @@ -118,6 +119,7 @@ export AutoGTPSA export AutoMooncake export AutoMooncakeForward export AutoPolyesterForwardDiff +export AutoReactant export AutoReverseDiff export AutoSymbolics export AutoTracker diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl new file mode 100644 index 000000000..21e1138a0 --- /dev/null +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -0,0 +1,14 @@ +using Pkg +Pkg.add("Reactant") + +using DifferentiationInterface +using DifferentiationInterfaceTest +using Reactant + +backend = AutoReactant() + +test_differentiation( + backend, DifferentiationInterfaceTest.default_scenarios(); + excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback), + logging = true +) From c7e75985c31405f512681ee3d40d14f7b4d08ed1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Nov 2025 10:35:45 +0100 Subject: [PATCH 2/5] Include contexts --- .../docs/src/explanation/backends.md | 2 + .../DifferentiationInterfaceReactantExt.jl | 3 +- .../onearg.jl | 77 +++++++++---------- .../utils.jl | 7 ++ .../test/Back/Reactant/test.jl | 11 ++- DifferentiationInterfaceTest/Project.toml | 2 + .../src/DifferentiationInterfaceTest.jl | 1 + .../src/scenarios/modify.jl | 4 +- 8 files changed, 63 insertions(+), 44 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 0da5201a9..58abbfa09 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -19,6 +19,8 @@ We support the following dense backend choices from [ADTypes.jl](https://github. - [`AutoTracker`](@extref ADTypes.AutoTracker) - [`AutoZygote`](@extref ADTypes.AutoZygote) +In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), sofar only for [`gradient`](@ref) and its variants. + ## Features Given a backend object, you can use: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl index 9aefb9381..7740b53a6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl @@ -2,11 +2,12 @@ module DifferentiationInterfaceReactantExt using ADTypes: ADTypes, AutoReactant import DifferentiationInterface as DI -using Reactant: @compile, to_rarray +using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode) DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode) +include("utils.jl") include("onearg.jl") end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl index c7ebce528..af82235a8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl @@ -8,23 +8,18 @@ struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{ compiled_value_and_gradient!::CVG! end -function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant, x) where {F} +function DI.prepare_gradient_nokwarg( + strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} _sig = DI.signature(f, rebackend, x; strict) backend = rebackend.mode - xr = to_rarray(x) - gr = to_rarray(similar(x)) - _gradient(_xr) = DI.gradient(f, backend, _xr) - _gradient!(_gr, _xr) = copy!(_gr, DI.gradient(f, backend, _xr)) - _value_and_gradient(_xr) = DI.value_and_gradient(f, backend, _xr) - function _value_and_gradient!(_gr, _xr) - y, __gr = DI.value_and_gradient(f, backend, _xr) - copy!(_gr, __gr) - return y, _gr - end - compiled_gradient = @compile _gradient(xr) - compiled_gradient! = @compile _gradient!(gr, xr) - compiled_value_and_gradient = @compile _value_and_gradient(xr) - compiled_value_and_gradient! = @compile _value_and_gradient!(gr, xr) + xr = to_reac(x) + gr = to_reac(similar(x)) + contextsr = map(to_reac, contexts) + compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...) + compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...) + compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...) + compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...) return ReactantGradientPrep( _sig, xr, @@ -37,45 +32,49 @@ function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant, end function DI.gradient( - f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x - ) where {F} + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode (; xr, compiled_gradient) = prep - copy!(xr, x) - gr = compiled_gradient(xr) - g = convert(typeof(x), gr) - return g + copyto!(xr, x) + contextsr = map(to_reac, contexts) + gr = compiled_gradient(f, backend, xr, contextsr...) + return gr end function DI.value_and_gradient( - f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x - ) where {F} + f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode (; xr, compiled_value_and_gradient) = prep - copy!(xr, x) - yr, gr = compiled_value_and_gradient(xr) - y = convert(eltype(x), yr) - g = convert(typeof(x), gr) - return y, g + copyto!(xr, x) + contextsr = map(to_reac, contexts) + yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...) + return yr, gr end function DI.gradient!( - f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x - ) where {F} + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode (; xr, gr, compiled_gradient!) = prep - copy!(xr, x) - compiled_gradient!(gr, xr) - return copy!(grad, gr) + copyto!(xr, x) + contextsr = map(to_reac, contexts) + compiled_gradient!(f, gr, backend, xr, contextsr...) + return copyto!(grad, gr) end function DI.value_and_gradient!( - f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x - ) where {F} + f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C} + ) where {F, C} DI.check_prep(f, prep, rebackend, x) + backend = rebackend.mode (; xr, gr, compiled_value_and_gradient!) = prep - copy!(xr, x) - yr, gr = compiled_value_and_gradient!(gr, xr) - y = convert(eltype(x), yr) - return y, copy!(grad, gr) + copyto!(xr, x) + contextsr = map(to_reac, contexts) + yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...) + return yr, copyto!(grad, gr) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl new file mode 100644 index 000000000..7c22e9a4c --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl @@ -0,0 +1,7 @@ +to_reac(x::AbstractArray) = to_rarray(x) +to_reac(x::ConcreteRArray) = x +to_reac(x::Number) = ConcreteRNumber(x) +to_reac(x::ConcreteRNumber) = x + +to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c))) +to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c))) diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl index 21e1138a0..b0640498a 100644 --- a/DifferentiationInterface/test/Back/Reactant/test.jl +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -1,14 +1,21 @@ using Pkg +Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl") Pkg.add("Reactant") using DifferentiationInterface using DifferentiationInterfaceTest using Reactant +using Test backend = AutoReactant() +@test check_available(backend) +@test check_inplace(backend) + test_differentiation( - backend, DifferentiationInterfaceTest.default_scenarios(); + backend, DifferentiationInterfaceTest.default_scenarios(; + include_constantified = true, include_cachified = false + ); excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback), - logging = true + logging = false ) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index a417d9eac..cb718f78c 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -10,6 +10,7 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -40,6 +41,7 @@ DataFrames = "1.6.1" DifferentiationInterface = "0.7.7" DocStringExtensions = "0.8,0.9" ForwardDiff = "0.10.36,1" +GPUArraysCore = "0.2.0" JET = "0.9,0.10,0.11" JLArrays = "0.1,0.2,0.3" LinearAlgebra = "1" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 51b216915..84d9d2b38 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -92,6 +92,7 @@ using DifferentiationInterface: using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES +using GPUArraysCore: @allowscalar using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent using PrecompileTools: @compile_workload diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 5761eb346..bef174384 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -224,8 +224,8 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries y = sc.f(x) if y isa Number - y_cache[1] = y - return y_cache[1] + @allowscalar y_cache[1] = y + return @allowscalar y_cache[1] else copyto!(y_cache, y) return copy(y_cache) From 8ca970ef0775ac2bb6eb84e51a23dbe5efe6cb15 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:19:42 +0100 Subject: [PATCH 3/5] Better docs --- DifferentiationInterface/docs/make.jl | 1 + DifferentiationInterface/docs/src/api.md | 6 ++ .../docs/src/explanation/backends.md | 4 + .../onearg.jl | 40 +++++---- .../utils.jl | 10 +-- .../src/DifferentiationInterface.jl | 2 + DifferentiationInterface/src/misc/reactant.jl | 84 +++++++++++++++++++ .../test/Back/Reactant/Project.toml | 10 +++ .../test/Back/Reactant/test.jl | 14 ++-- 9 files changed, 138 insertions(+), 33 deletions(-) create mode 100644 DifferentiationInterface/src/misc/reactant.jl create mode 100644 DifferentiationInterface/test/Back/Reactant/Project.toml diff --git a/DifferentiationInterface/docs/make.jl b/DifferentiationInterface/docs/make.jl index 334b5b83e..45ae93590 100644 --- a/DifferentiationInterface/docs/make.jl +++ b/DifferentiationInterface/docs/make.jl @@ -10,6 +10,7 @@ using Zygote: Zygote links = InterLinks( "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/", + "Reactant" => "https://enzymead.github.io/Reactant.jl/stable/", "SparseConnectivityTracer" => "https://adrianhill.de/SparseConnectivityTracer.jl/stable/", "SparseMatrixColorings" => "https://gdalle.github.io/SparseMatrixColorings.jl/stable/", "Symbolics" => "https://symbolics.juliasymbolics.org/stable/", diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 108f1c03d..08fd8282f 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -144,3 +144,9 @@ DifferentiationInterface.AutoReverseFromPrimitive ```@docs DifferentiationInterface.Prep ``` + +### Reactant + +```@docs +DifferentiationInterface.to_reactant +``` diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 58abbfa09..3c6eb66a9 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -169,6 +169,10 @@ If a GTPSA [`Descriptor`](https://bmad-sim.github.io/GTPSA.jl/stable/man/b_descr Most operators fall back on `AutoForwardDiff`. +### Reactant + +See the docstring for [`AutoReactant`](@ref). + ### ReverseDiff With `AutoReverseDiff(compile=false)`, preparation preallocates a [config](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractConfig-API). diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl index af82235a8..d0e4acebd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl @@ -13,9 +13,9 @@ function DI.prepare_gradient_nokwarg( ) where {F, C} _sig = DI.signature(f, rebackend, x; strict) backend = rebackend.mode - xr = to_reac(x) - gr = to_reac(similar(x)) - contextsr = map(to_reac, contexts) + xr = x isa ConcreteRArray ? nothing : ConcreteRArray(x) + gr = x isa ConcreteRArray ? nothing : ConcreteRArray(similar(x)) + contextsr = map(_to_reactant, contexts) compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...) compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...) compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...) @@ -36,10 +36,9 @@ function DI.gradient( ) where {F, C} DI.check_prep(f, prep, rebackend, x) backend = rebackend.mode - (; xr, compiled_gradient) = prep - copyto!(xr, x) - contextsr = map(to_reac, contexts) - gr = compiled_gradient(f, backend, xr, contextsr...) + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + contextsr = map(_to_reactant, contexts) + gr = prep.compiled_gradient(f, backend, xr, contextsr...) return gr end @@ -48,10 +47,9 @@ function DI.value_and_gradient( ) where {F, C} DI.check_prep(f, prep, rebackend, x) backend = rebackend.mode - (; xr, compiled_value_and_gradient) = prep - copyto!(xr, x) - contextsr = map(to_reac, contexts) - yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...) + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + contextsr = map(_to_reactant, contexts) + yr, gr = prep.compiled_value_and_gradient(f, backend, xr, contextsr...) return yr, gr end @@ -60,11 +58,11 @@ function DI.gradient!( ) where {F, C} DI.check_prep(f, prep, rebackend, x) backend = rebackend.mode - (; xr, gr, compiled_gradient!) = prep - copyto!(xr, x) - contextsr = map(to_reac, contexts) - compiled_gradient!(f, gr, backend, xr, contextsr...) - return copyto!(grad, gr) + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + gr = isnothing(prep.gr) ? grad : prep.gr + contextsr = map(_to_reactant, contexts) + prep.compiled_gradient!(f, gr, backend, xr, contextsr...) + return isnothing(prep.gr) ? grad : copyto!(grad, gr) end function DI.value_and_gradient!( @@ -72,9 +70,9 @@ function DI.value_and_gradient!( ) where {F, C} DI.check_prep(f, prep, rebackend, x) backend = rebackend.mode - (; xr, gr, compiled_value_and_gradient!) = prep - copyto!(xr, x) - contextsr = map(to_reac, contexts) - yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...) - return yr, copyto!(grad, gr) + xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x) + gr = isnothing(prep.gr) ? grad : prep.gr + contextsr = map(_to_reactant, contexts) + yr, gr = prep.compiled_value_and_gradient!(f, gr, backend, xr, contextsr...) + return yr, isnothing(prep.gr) ? grad : copyto!(grad, gr) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl index 7c22e9a4c..d1579f975 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/utils.jl @@ -1,7 +1,3 @@ -to_reac(x::AbstractArray) = to_rarray(x) -to_reac(x::ConcreteRArray) = x -to_reac(x::Number) = ConcreteRNumber(x) -to_reac(x::ConcreteRNumber) = x - -to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c))) -to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c))) +_to_reactant(x) = DI.to_reactant(x) +_to_reactant(c::DI.Constant) = DI.Constant(_to_reactant(DI.unwrap(c))) +_to_reactant(c::DI.Cache) = DI.Cache(_to_reactant(DI.unwrap(c))) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 393b9051a..a3cddd305 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -69,6 +69,7 @@ include("misc/sparsity_detector.jl") include("misc/simple_finite_diff.jl") include("misc/zero_backends.jl") include("misc/overloading.jl") +include("misc/reactant.jl") ## Exported @@ -132,6 +133,7 @@ export AutoSparse @public inner, outer @public AutoForwardFromPrimitive, AutoReverseFromPrimitive @public Prep +@public to_reactant include("init.jl") diff --git a/DifferentiationInterface/src/misc/reactant.jl b/DifferentiationInterface/src/misc/reactant.jl new file mode 100644 index 000000000..0d0ba9219 --- /dev/null +++ b/DifferentiationInterface/src/misc/reactant.jl @@ -0,0 +1,84 @@ +""" +!!! tip "DI-specific information" + This part of the docstring is related to the use of `AutoReactant` inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), or DI for short. + Reactant's tutorial on [partial evaluation](https://enzymead.github.io/Reactant.jl/stable/tutorials/partial-evaluation) is useful reading to understand what follows. + +The `AutoReactant` backend inside DI imposes the following restrictions / assumptions: + +- The only supported operator (at the moment) is `DI.gradient` (along with its variants). +- The input `x` must be an `AbstractArray` such that `Reactant.ConcreteRArray(x)` is well-defined. +- By default, contexts such as `DI.Constant` and `DI.Cache` will be partially evaluated inside the compiled differentiation operator at preparation time. This means that the context value provided at preparation will be reused at every subsequent execution, while the context value provided at execution will be ignored. In particular, `DI.Cache` contexts will usually error and `DI.Constant` contexts will be frozen to one value. + +To disable partial evaluation and enforce tracing of contexts instead, first wrap them into types that _you own_. +Then, overload [`DifferentiationInterface.to_reactant`](@ref) on these types to perform tracing in the way you see fit, for instance with `Reactant.to_rarray`. +Every value you choose not to trace will still be partially evaluated at preparation time. + +# Example + +```jldoctest +using DifferentiationInterface +import DifferentiationInterface as DI +import Reactant + +struct MyArgument{T1 <: Number, T2 <: AbstractArray} + u::T1 + v::T2 +end + +f(x, a::MyArgument) = a.u * sum(a.v .* x .^ 2) + +DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false) + +# preparation time +x0 = zeros(2) +a0 = MyArgument(1.0, [2.0, 3.0]) + +# execution time +x = [4.0, 5.0] +a = MyArgument(6.0, [7.0, 8.0]) + +backend = AutoReactant() +prep = prepare_gradient(f, backend, x0, Constant(a0)); + +g = gradient(f, prep, backend, x, Constant(a)) +g ≈ a0.u * 2 * (a.v .* x) # a0.u is partially evaluated, a0.v is traced + +# output + +true +``` +""" +AutoReactant + +""" + to_reactant(a) + +Convert an argument `a` to an object `ar` containing the same values, where all the fields and subfields that can contain active (differentiated) data have been translated to [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) types such as [`ConcreteRArray`](@extref Reactant.ConcreteRArray) or [`ConcreteRNumber`](@extref Reactant.ConcreteRNumber). + +!!! danger + DifferentiationInterface.jl implements this function as the identity, on purpose. + It should not be overloaded on base types, but only on types that you own, to modify the default behavior of `AutoReactant`. + +# Example + +```jldoctest +import DifferentiationInterface as DI +import Reactant + +struct MyArgument{T1 <: Number, T2 <: AbstractArray} + u::T1 + v::T2 +end + +DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false) + +a = MyArgument(1.0, [2.0, 3.0]) +ar = DI.to_reactant(a) +ar isa MyArgument{Float64, <:Reactant.ConcreteRArray} + +# output + +true +``` +""" +to_reactant(x) = x diff --git a/DifferentiationInterface/test/Back/Reactant/Project.toml b/DifferentiationInterface/test/Back/Reactant/Project.toml new file mode 100644 index 000000000..9ec59dc7b --- /dev/null +++ b/DifferentiationInterface/test/Back/Reactant/Project.toml @@ -0,0 +1,10 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl index b0640498a..cbdecfcc7 100644 --- a/DifferentiationInterface/test/Back/Reactant/test.jl +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -1,10 +1,10 @@ -using Pkg -Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl") -Pkg.add("Reactant") +include("../../testutils.jl") using DifferentiationInterface +import DifferentiationInterface as DI using DifferentiationInterfaceTest -using Reactant +import DifferentiationInterfaceTest as DIT +import Enzyme, Reactant using Test backend = AutoReactant() @@ -12,9 +12,13 @@ backend = AutoReactant() @test check_available(backend) @test check_inplace(backend) +scen1 = DIT.Scenario( + +) + test_differentiation( backend, DifferentiationInterfaceTest.default_scenarios(; - include_constantified = true, include_cachified = false + include_constantified = false, include_cachified = false ); excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback), logging = false From a3c6aa55751996b336306cd6ec4f0b44a69938a3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:05:14 +0100 Subject: [PATCH 4/5] Fixes --- DifferentiationInterface/docs/Project.toml | 7 ++++--- DifferentiationInterface/test/Back/Reactant/test.jl | 4 ---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/DifferentiationInterface/docs/Project.toml b/DifferentiationInterface/docs/Project.toml index 24bc8895a..67241f1fc 100644 --- a/DifferentiationInterface/docs/Project.toml +++ b/DifferentiationInterface/docs/Project.toml @@ -9,10 +9,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +DifferentiationInterface = {path = ".."} + [compat] ADTypes = "1" BenchmarkTools = "1" @@ -26,6 +30,3 @@ SparseConnectivityTracer = "1.1.2" SparseMatrixColorings = "0.4.23" Zygote = "0.7.10" julia = "1.10.10" - -[sources] -DifferentiationInterface = { path = ".." } diff --git a/DifferentiationInterface/test/Back/Reactant/test.jl b/DifferentiationInterface/test/Back/Reactant/test.jl index cbdecfcc7..80dd7f6d1 100644 --- a/DifferentiationInterface/test/Back/Reactant/test.jl +++ b/DifferentiationInterface/test/Back/Reactant/test.jl @@ -12,10 +12,6 @@ backend = AutoReactant() @test check_available(backend) @test check_inplace(backend) -scen1 = DIT.Scenario( - -) - test_differentiation( backend, DifferentiationInterfaceTest.default_scenarios(; include_constantified = false, include_cachified = false From f08d4702f34cd3ccf15c8910c755e3b915784f24 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:18:54 +0100 Subject: [PATCH 5/5] Dcocs --- DifferentiationInterface/docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 08fd8282f..0d03227aa 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -148,5 +148,6 @@ DifferentiationInterface.Prep ### Reactant ```@docs +AutoReactant DifferentiationInterface.to_reactant ```