Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 01a5e5f

Browse files
Merge pull request #286 from avik-pal/ap/format
Run Formatter
2 parents 4f9aa7d + b31d02c commit 01a5e5f

12 files changed

+83
-59
lines changed

ext/SparseDiffToolsEnzymeExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module SparseDiffToolsEnzymeExt
22

33
import ArrayInterface: fast_scalar_indexing
44
import SparseDiffTools: __f̂, __maybe_copy_x, __jacobian!, __gradient, __gradient!,
5-
AutoSparseEnzyme, __test_backend_loaded
5+
AutoSparseEnzyme, __test_backend_loaded
66
# FIXME: For Enzyme we currently assume reverse mode
77
import ADTypes: AutoEnzyme
88
using Enzyme

ext/SparseDiffToolsPolyesterForwardDiffExt.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module SparseDiffToolsPolyesterForwardDiffExt
33
using ADTypes, SparseDiffTools, PolyesterForwardDiff
44
import ForwardDiff
55
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
6-
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache, sparse_jacobian!,
7-
sparse_jacobian_static_array, __standard_tag, __chunksize
6+
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache,
7+
sparse_jacobian!,
8+
sparse_jacobian_static_array, __standard_tag, __chunksize
89

910
struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
1011
AbstractMaybeSparseJacobianCache
@@ -15,8 +16,10 @@ struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
1516
x::X
1617
end
1718

18-
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
19-
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x;
19+
function sparse_jacobian_cache(
20+
ad::Union{AutoSparsePolyesterForwardDiff,
21+
AutoPolyesterForwardDiff},
22+
sd::AbstractMaybeSparsityDetection, f::F, x;
2023
fx = nothing) where {F}
2124
coloring_result = sd(ad, f, x)
2225
fx = fx === nothing ? similar(f(x)) : fx
@@ -35,8 +38,10 @@ function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
3538
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
3639
end
3740

38-
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
39-
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f!::F, fx,
41+
function sparse_jacobian_cache(
42+
ad::Union{AutoSparsePolyesterForwardDiff,
43+
AutoPolyesterForwardDiff},
44+
sd::AbstractMaybeSparsityDetection, f!::F, fx,
4045
x) where {F}
4146
coloring_result = sd(ad, f!, fx, x)
4247
if coloring_result isa NoMatrixColoring

ext/SparseDiffToolsZygoteExt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import Setfield: @set!
88
import Tricks: static_hasmethod
99

1010
import SparseDiffTools: numback_hesvec!,
11-
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!, auto_vecjac
11+
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!,
12+
auto_vecjac
1213
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
1314
import ADTypes: AutoZygote, AutoSparseZygote
1415

src/SparseDiffTools.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ import Graphs: SimpleGraph
1010
using FiniteDiff, ForwardDiff
1111
@reexport using ADTypes
1212
import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode,
13-
AbstractSparseReverseMode, AbstractSparseFiniteDifferences, AbstractReverseMode
13+
AbstractSparseReverseMode, AbstractSparseFiniteDifferences,
14+
AbstractReverseMode
1415
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
1516
# Array Packages
1617
using ArrayInterface, SparseArrays
@@ -66,21 +67,22 @@ function auto_vecjac! end
6667

6768
# Coloring Algorithms
6869
export AcyclicColoring,
69-
BacktrackingColor, ContractionColor, GreedyD1Color, GreedyStar1Color, GreedyStar2Color
70+
BacktrackingColor, ContractionColor, GreedyD1Color, GreedyStar1Color,
71+
GreedyStar2Color
7072
export matrix2graph, matrix_colors
7173
# Sparse Jacobian Computation
7274
export ForwardColorJacCache, forwarddiff_color_jacobian, forwarddiff_color_jacobian!
7375
# Sparse Hessian Computation
7476
export numauto_color_hessian, numauto_color_hessian!, autoauto_color_hessian,
75-
autoauto_color_hessian!, ForwardAutoColorHesCache, ForwardColorHesCache
77+
autoauto_color_hessian!, ForwardAutoColorHesCache, ForwardColorHesCache
7678
# JacVec Products
7779
export auto_jacvec, auto_jacvec!, num_jacvec, num_jacvec!
7880
# VecJac Products
7981
export num_vecjac, num_vecjac!, auto_vecjac, auto_vecjac!
8082
# HesVec Products
8183
export numauto_hesvec,
82-
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec, numback_hesvec!,
83-
num_hesvec, num_hesvec!, autoback_hesvec, autoback_hesvec!
84+
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec, numback_hesvec!,
85+
num_hesvec, num_hesvec!, autoback_hesvec, autoback_hesvec!
8486
# HesVecGrad Products
8587
export num_hesvecgrad, num_hesvecgrad!, auto_hesvecgrad, auto_hesvecgrad!
8688
# Operators
@@ -91,7 +93,7 @@ export update_coefficients, update_coefficients!, value!
9193
export AutoSparseEnzyme
9294

9395
export NoSparsityDetection, SymbolicsSparsityDetection, JacPrototypeSparsityDetection,
94-
PrecomputedJacobianColorvec, ApproximateJacobianSparsity, AutoSparsityDetection
96+
PrecomputedJacobianColorvec, ApproximateJacobianSparsity, AutoSparsityDetection
9597
export sparse_jacobian, sparse_jacobian_cache, sparse_jacobian!
9698
export init_jacobian
9799

src/differentiation/common.jl

+12-10
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,20 @@ function JacFunctionWrapper(f::F, fu_, u, p, t;
5353
oop = static_hasmethod(f, typeof((u,)))
5454
if iip || oop
5555
if p !== nothing || t !== nothing
56-
Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we
57-
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
56+
Base.depwarn(
57+
"""`p` and/or `t` provided and are not `nothing`. But we
58+
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
5859
59-
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not
60-
be supplied.
61-
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
60+
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not
61+
be supplied.
62+
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
6263
63-
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in
64-
future breaking releases, this check will be performed last, which means
65-
that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given
66-
precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be
67-
given precedence.""", :JacFunctionWrapper)
64+
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in
65+
future breaking releases, this check will be performed last, which means
66+
that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given
67+
precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be
68+
given precedence.""",
69+
:JacFunctionWrapper)
6870
end
6971
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
7072
fu, p, t)

src/differentiation/compute_hessian_ad.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ function ForwardColorHesCache(f, x::AbstractVector{<:Number},
4141
if sparsity === nothing
4242
sparsity = sparse(ones(length(x), length(x)))
4343
end
44-
return ForwardColorHesCache(sparsity, colorvec, ncolors, D, buffer, g1!, grad_config, G,
44+
return ForwardColorHesCache(
45+
sparsity, colorvec, ncolors, D, buffer, g1!, grad_config, G,
4546
G2)
4647
end
4748

@@ -123,12 +124,14 @@ function ForwardAutoColorHesCache(f, x::AbstractVector{V},
123124
return ForwardAutoColorHesCache(jac_cache, g!, sparsity, colorvec)
124125
end
125126

126-
function autoauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
127+
function autoauto_color_hessian!(
128+
H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
127129
hes_cache::ForwardAutoColorHesCache)
128130
forwarddiff_color_jacobian!(H, hes_cache.grad!, x, hes_cache.jac_cache)
129131
end
130132

131-
function autoauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
133+
function autoauto_color_hessian!(
134+
H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
132135
colorvec::AbstractVector{<:Integer} = eachindex(x),
133136
sparsity::Union{AbstractMatrix, Nothing} = nothing)
134137
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)

src/differentiation/compute_jacobian_ad.jl

+11-8
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function ForwardColorJacCache(f::F, x, _chunksize = nothing; dx = nothing, tag =
4242
_t = Dual{
4343
T,
4444
eltype(x),
45-
getsize(chunksize),
45+
getsize(chunksize)
4646
}.(vec(x), ForwardDiff.Partials.(first(p)))
4747
t = ArrayInterface.restructure(x, _t)
4848
end
@@ -54,7 +54,8 @@ function ForwardColorJacCache(f::F, x, _chunksize = nothing; dx = nothing, tag =
5454
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p, 1),
5555
1) .* false
5656
_pi = adapt(parameterless_type(dx), [tup for i in 1:length(dx)])
57-
fx = reshape(Dual{T, eltype(dx), length(tup)}.(vec(dx), ForwardDiff.Partials.(_pi)),
57+
fx = reshape(
58+
Dual{T, eltype(dx), length(tup)}.(vec(dx), ForwardDiff.Partials.(_pi)),
5859
size(dx)...)
5960
_dx = dx
6061
end
@@ -204,7 +205,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
204205
dx = vec(partials.(fx, j))
205206
pick_inds = [i
206207
for i in 1:length(rows_index)
207-
if colorvec[cols_index[i]] == color_i]
208+
if colorvec[cols_index[i]] == color_i]
208209
rows_index_c = rows_index[pick_inds]
209210
cols_index_c = cols_index[pick_inds]
210211
if J isa SparseMatrixCSC || j > 1
@@ -232,8 +233,9 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
232233
for j in 1:chunksize
233234
col_index = (i - 1) * chunksize + j
234235
(col_index > ncols) && return J
235-
Ji = mapreduce(i -> i == col_index ? partials.(vec(fx), j) :
236-
adapt(parameterless_type(J), zeros(eltype(J), nrows)),
236+
Ji = mapreduce(
237+
i -> i == col_index ? partials.(vec(fx), j) :
238+
adapt(parameterless_type(J), zeros(eltype(J), nrows)),
237239
hcat, 1:ncols)
238240
if j == 1 && i == 1
239241
J .= (size(Ji) != size(J) ? reshape(Ji, size(J)) : Ji) # overwrite pre-allocated matrix
@@ -281,7 +283,7 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
281283
dx = vec(partials.(fx, j))
282284
pick_inds = [i
283285
for i in 1:length(rows_index)
284-
if colorvec[cols_index[i]] == color_i]
286+
if colorvec[cols_index[i]] == color_i]
285287
rows_index_c = rows_index[pick_inds]
286288
cols_index_c = cols_index[pick_inds]
287289
if J isa SparseMatrixCSC
@@ -302,8 +304,9 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
302304
for j in 1:chunksize
303305
col_index = (i - 1) * chunksize + j
304306
(col_index > ncols) && return J
305-
Ji = mapreduce(i -> i == col_index ? partials.(vec(fx), j) :
306-
adapt(parameterless_type(J), zeros(eltype(J), nrows)),
307+
Ji = mapreduce(
308+
i -> i == col_index ? partials.(vec(fx), j) :
309+
adapt(parameterless_type(J), zeros(eltype(J), nrows)),
307310
hcat, 1:ncols)
308311
J = J + (size(Ji) != size(J) ? reshape(Ji, size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
309312
end

src/differentiation/jaches_products.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ get_tag(::Dual{T, V, N}) where {T, V, N} = T
66
# J(f(x))*v
77
function auto_jacvec!(dy, f, x, v,
88
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
9-
eltype(x), 1,
9+
eltype(x), 1
1010
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
1111
cache2 = similar(cache1))
1212
cache1 .= Dual{
1313
get_tag(cache1),
1414
eltype(x),
15-
1,
15+
1
1616
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
1717
f(cache2, cache1)
1818
vecdy = _vec(dy)
@@ -27,7 +27,7 @@ function auto_jacvec(f, x, v)
2727
y = ForwardDiff.Dual{
2828
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
2929
eltype(x),
30-
1,
30+
1
3131
}.(x, ForwardDiff.Partials.(tuple.(vv)))
3232
vec(partials.(vec(f(y)), 1))
3333
end
@@ -113,17 +113,17 @@ end
113113

114114
function autonum_hesvec!(dy, f, x, v,
115115
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
116-
eltype(x), 1,
116+
eltype(x), 1
117117
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
118118
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
119-
eltype(x), 1,
119+
eltype(x), 1
120120
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
121121
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
122122
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
123123
cache1 .= Dual{
124124
get_tag(cache1),
125125
eltype(x),
126-
1,
126+
1
127127
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
128128
g(cache2, cache1)
129129
dy .= partials.(cache2, 1)
@@ -159,15 +159,15 @@ end
159159

160160
function auto_hesvecgrad!(dy, g, x, v,
161161
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
162-
eltype(x), 1,
162+
eltype(x), 1
163163
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
164164
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
165-
eltype(x), 1,
165+
eltype(x), 1
166166
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
167167
cache2 .= Dual{
168168
get_tag(cache2),
169169
eltype(x),
170-
1,
170+
1
171171
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
172172
g(cache3, cache2)
173173
dy .= partials.(cache3, 1)
@@ -177,7 +177,7 @@ function auto_hesvecgrad(g, x, v)
177177
y = Dual{
178178
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
179179
eltype(x),
180-
1,
180+
1
181181
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
182182
partials.(g(y), 1)
183183
end
@@ -272,10 +272,10 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
272272
(cache1, cache2), num_jacvec, num_jacvec!
273273
elseif autodiff isa AutoForwardDiff
274274
cache1 = Dual{
275-
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
275+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
276276
}.(u, ForwardDiff.Partials.(tuple.(u)))
277277
cache2 = Dual{
278-
typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1,
278+
typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1
279279
}.(fu, ForwardDiff.Partials.(tuple.(fu)))
280280

281281
(cache1, cache2), auto_jacvec, auto_jacvec!
@@ -307,7 +307,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing;
307307
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
308308

309309
cache1 = Dual{
310-
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
310+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
311311
}.(u, ForwardDiff.Partials.(tuple.(u)))
312312
cache2 = copy(cache1)
313313

@@ -338,7 +338,7 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
338338
(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
339339
elseif autodiff isa AutoForwardDiff
340340
cache1 = Dual{
341-
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
341+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
342342
}.(u, ForwardDiff.Partials.(tuple.(u)))
343343
cache2 = copy(cache1)
344344

src/highlevel/coloring.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ end
3232

3333
# Approximate Jacobian Sparsity Detection
3434
## Right now we hardcode it to use `ForwardDiff`
35-
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, x; fx = nothing,
35+
function (alg::ApproximateJacobianSparsity)(
36+
ad::AbstractSparseADType, f::F, x; fx = nothing,
3637
kwargs...) where {F}
3738
if !(ad isa AutoSparseForwardDiff)
3839
@warn "$(ad) support for approximate jacobian not implemented. Using ForwardDiff instead." maxlog=1
@@ -71,7 +72,8 @@ function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, fx,
7172
fx, kwargs...)
7273
end
7374

74-
function (alg::ApproximateJacobianSparsity)(ad::AutoSparseFiniteDiff, f::F, x; fx = nothing,
75+
function (alg::ApproximateJacobianSparsity)(
76+
ad::AutoSparseFiniteDiff, f::F, x; fx = nothing,
7577
kwargs...) where {F}
7678
@unpack ntrials, rng = alg
7779
fx = fx === nothing ? f(x) : fx
@@ -101,7 +103,8 @@ function (alg::ApproximateJacobianSparsity)(ad::AutoSparseFiniteDiff, f!::F, fx,
101103
FiniteDiff.finite_difference_jacobian!(J_cache, f!, x_, cache)
102104
@. J += (abs(J_cache) .≥ ε) # hedge against numerical issues
103105
end
104-
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(ad, f!, fx,
106+
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(
107+
ad, f!, fx,
105108
x; kwargs...)
106109
end
107110

src/highlevel/common.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,10 @@ function init_jacobian end
269269
const __init_𝒥 = init_jacobian
270270

271271
# Misc Functions
272-
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
273-
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}, x) where {C}
272+
function __chunksize(
273+
::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
274+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}},
275+
x) where {C}
274276
C isa ForwardDiff.Chunk && return C
275277
return __chunksize(Val(C), x)
276278
end

test/test_jaches_products.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
7878
cache3 = ForwardDiff.Dual{
7979
typeof(ForwardDiff.Tag(Nothing, eltype(x))),
8080
eltype(x),
81-
1,
81+
1
8282
}.(x, ForwardDiff.Partials.(tuple.(v)))
8383
cache4 = ForwardDiff.Dual{
8484
typeof(ForwardDiff.Tag(Nothing, eltype(x))),
8585
eltype(x),
86-
1,
86+
1
8787
}.(x, ForwardDiff.Partials.(tuple.(v)))
8888
@test autoback_hesvec!(dy, g, x, v) ForwardDiff.hessian(g, x) * v
8989
@test autoback_hesvec!(dy, g, x, v, cache3, cache4) ForwardDiff.hessian(g, x) * v

0 commit comments

Comments
 (0)