This repository was archived by the owner on Apr 23, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathSparseDiffToolsPolyesterExt.jl
71 lines (62 loc) · 2.29 KB
/
SparseDiffToolsPolyesterExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
module SparseDiffToolsPolyesterExt
using Adapt, ArrayInterface, ForwardDiff, FiniteDiff, Polyester, SparseDiffTools,
SparseArrays
import SparseDiffTools: polyesterforwarddiff_color_jacobian, ForwardColorJacCache,
__parameterless_type
"""
    cld_fast(n, d)

Ceiling division `cld(n, d)` computed with a single unsigned division plus a correction.

Arguments of different types are first converted to their promoted type with `%`
(wrapping conversion) and forwarded to the same-type method.
"""
function cld_fast(a::A, b::B) where {A, B}
    common = promote_type(A, B)
    return cld_fast(a % common, b % common)
end
# Same-type kernel. `Base.udiv_int` treats both operands as unsigned, so the result only
# agrees with `cld` for n >= 0 and d > 0; these preconditions are not checked here.
# NOTE(review): d == 0 is not guarded — confirm callers never pass it.
function cld_fast(n::T, d::T) where {T}
    q = Base.udiv_int(n, d)   # truncated quotient
    inexact = d * q != n      # true when the division left a remainder
    return q + inexact        # adds 0 or 1 (the Bool promotes to the integer type)
end
# Return the (row, column) coordinates of every entry of `J` that may be nonzero, as two
# materialized vectors. Without a sparsity pattern every entry of `J` is a candidate.
function _polyester_structural_indices(J::AbstractMatrix, ::Nothing)
    cartind = vec(CartesianIndices(J))
    return [c[1] for c in cartind], [c[2] for c in cartind]
end
# With a sparsity pattern, use its structural nonzeros (materialized so they can be indexed).
function _polyester_structural_indices(::AbstractMatrix, sparsity)
    rows, cols = ArrayInterface.findstructralnz(sparsity)
    return [rows[k] for k in 1:length(rows)], [cols[k] for k in 1:length(cols)]
end

# Bucket structural-entry positions by the color of their column:
# `groups[c]` holds every `idx` with `colorvec[cols_index[idx]] == c`, for c in 1:maxcolor.
function _polyester_color_groups(cols_index, colorvec, maxcolor::Integer)
    groups = [Int[] for _ in 1:maxcolor]
    for (idx, col) in enumerate(cols_index)
        color = colorvec[col]
        # colors outside 1:maxcolor are never visited by the chunk loop, so drop them
        1 <= color <= maxcolor && push!(groups[color], idx)
    end
    return groups
end

"""
    polyesterforwarddiff_color_jacobian(J, f, x, jac_cache::ForwardColorJacCache)

Fill `J` in place with the colored forward-mode Jacobian of `f` at `x` and return `J`.

`f` is called out of place as `f(t_)`. Here `t_` has the shape of `jac_cache.t` and its
elements are built from `x` and the seed partials `jac_cache.p[i]` (presumably
`ForwardDiff.Dual`s). Chunks `i = 1:length(jac_cache.p)` are evaluated in parallel with
`Polyester.batch`. Partial `j` of chunk `i` is the directional derivative for color
`(i - 1) * jac_cache.chunksize + j`. It is scattered into every structural entry of `J`
whose column has that color.

If `jac_cache.sparsity === nothing`, every entry of `J` is treated as structurally nonzero.
Otherwise `ArrayInterface.findstructralnz(jac_cache.sparsity)` gives the pattern. `J` is
zeroed first (only its stored values when `J` is an `AbstractSparseMatrix`).
"""
function polyesterforwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
        x::AbstractArray{<:Number}, jac_cache::ForwardColorJacCache) where {F}
    t = jac_cache.t
    p = jac_cache.p
    colorvec = jac_cache.colorvec
    chunksize = jac_cache.chunksize
    maxcolor = maximum(colorvec)
    vecx = vec(x)

    # Each captured variable below is assigned exactly once, so the closure passed to
    # `batch` captures it without boxing.
    rows_index, cols_index = _polyester_structural_indices(J, jac_cache.sparsity)
    # Group once up front instead of rescanning every structural entry for each color.
    color_groups = _polyester_color_groups(cols_index, colorvec, maxcolor)

    if J isa AbstractSparseMatrix
        fill!(nonzeros(J), zero(eltype(J)))
    else
        fill!(J, zero(eltype(J)))
    end

    batch((length(p), min(length(p), Threads.nthreads()))) do _, start, stop
        for i in start:stop
            t_ = reshape(eltype(t).(vecx, ForwardDiff.Partials.(p[i])), size(t))
            fx = f(t_)
            for j in 1:chunksize
                color = (i - 1) * chunksize + j
                # the last chunk may carry more partials than there are remaining colors
                color > maxcolor && break
                group = color_groups[color]
                isempty(group) && continue
                # `dxj` must not share a name with any variable of the enclosing function.
                # A captured variable that is reassigned here would be boxed and shared by
                # all Polyester tasks, which races.
                dxj = vec(ForwardDiff.partials.(fx, j))
                # Different chunks own different colors, hence different columns of J.
                for idx in group
                    row = rows_index[idx]
                    J[row, cols_index[idx]] = dxj[row]
                end
            end
        end
    end
    return J
end
end