Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform Step Refactor for SDIRK #2440

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
137 changes: 23 additions & 114 deletions lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ end
z₂::uType
z₃::uType
z₄::uType
k1::kType
k2::kType
k3::kType
k4::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -98,15 +95,9 @@ function alg_cache(alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits},
fsalfirst = zero(rate_prototype)

if f isa SplitFunction
k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
ks = [zero(u) for _ in 1:4]
else
k1 = nothing
k2 = nothing
k3 = nothing
k4 = nothing
ks = [nothing for _ in 1:4]
uf = UJacobianWrapper(f, t, p)
end

Expand All @@ -117,8 +108,7 @@ function alg_cache(alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits},
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)

KenCarp3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2,
k3, k4, atmp, nlsolver, tab, alg.step_limiter!)
KenCarp3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, ks, atmp, nlsolver, tab, alg.step_limiter!)
end

@cache mutable struct CFNLIRK3ConstantCache{N, Tab} <: SDIRKConstantCache
Expand Down Expand Up @@ -147,10 +137,7 @@ end
z₂::uType
z₃::uType
z₄::uType
k1::kType
k2::kType
k3::kType
k4::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -166,10 +153,7 @@ function alg_cache(alg::CFNLIRK3, u, rate_prototype, ::Type{uEltypeNoUnits},
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
fsalfirst = zero(rate_prototype)

k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
ks = [zero(u) for _ in 1:4]

z₁ = zero(u)
z₂ = zero(u)
Expand All @@ -178,7 +162,7 @@ function alg_cache(alg::CFNLIRK3, u, rate_prototype, ::Type{uEltypeNoUnits},
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)

CFNLIRK3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver, tab)
CFNLIRK3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, ks, atmp, nlsolver, tab)
end

@cache mutable struct Kvaerno4ConstantCache{N, Tab} <: SDIRKConstantCache
Expand Down Expand Up @@ -263,12 +247,7 @@ end
z₄::uType
z₅::uType
z₆::uType
k1::kType
k2::kType
k3::kType
k4::kType
k5::kType
k6::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -288,19 +267,9 @@ function alg_cache(alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits},
fsalfirst = zero(rate_prototype)

if f isa SplitFunction
k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
k5 = zero(u)
k6 = zero(u)
ks = [zero(u) for _ in 1:6]
else
k1 = nothing
k2 = nothing
k3 = nothing
k4 = nothing
k5 = nothing
k6 = nothing
ks = [nothing for _ in 1:6]
uf = UJacobianWrapper(f, t, p)
end

Expand All @@ -314,7 +283,7 @@ function alg_cache(alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)

KenCarp4Cache(
u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, k1, k2, k3, k4, k5, k6, atmp,
u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, ks, atmp,
nlsolver, tab, alg.step_limiter!)
end

Expand Down Expand Up @@ -408,14 +377,7 @@ end
z₆::uType
z₇::uType
z₈::uType
k1::kType
k2::kType
k3::kType
k4::kType
k5::kType
k6::kType
k7::kType
k8::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -433,23 +395,9 @@ function alg_cache(alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits},
fsalfirst = zero(rate_prototype)

if f isa SplitFunction
k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
k5 = zero(u)
k6 = zero(u)
k7 = zero(u)
k8 = zero(u)
ks = [zero(u) for _ in 1:8]
else
k1 = nothing
k2 = nothing
k3 = nothing
k4 = nothing
k5 = nothing
k6 = nothing
k7 = nothing
k8 = nothing
ks = [nothing for _ in 1:8]
end

z₁ = zero(u)
Expand All @@ -464,7 +412,7 @@ function alg_cache(alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)

KenCarp5Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈,
k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab, alg.step_limiter!)
ks, atmp, nlsolver, tab, alg.step_limiter!)
end

@cache mutable struct KenCarp47ConstantCache{N, Tab} <: SDIRKConstantCache
Expand Down Expand Up @@ -496,13 +444,7 @@ end
z₅::uType
z₆::uType
z₇::uType
k1::kType
k2::kType
k3::kType
k4::kType
k5::kType
k6::kType
k7::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -520,21 +462,9 @@ function alg_cache(alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits},
fsalfirst = zero(rate_prototype)

if f isa SplitFunction
k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
k5 = zero(u)
k6 = zero(u)
k7 = zero(u)
ks = [zero(u) for _ in 1:7]
else
k1 = nothing
k2 = nothing
k3 = nothing
k4 = nothing
k5 = nothing
k6 = nothing
k7 = nothing
ks = [nothing for _ in 1:7]
end

z₁ = zero(u)
Expand All @@ -548,7 +478,7 @@ function alg_cache(alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)

KenCarp47Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇,
k1, k2, k3, k4, k5, k6, k7, atmp, nlsolver, tab)
ks, atmp, nlsolver, tab)
end

@cache mutable struct KenCarp58ConstantCache{N, Tab} <: SDIRKConstantCache
Expand Down Expand Up @@ -581,14 +511,7 @@ end
z₆::uType
z₇::uType
z₈::uType
k1::kType
k2::kType
k3::kType
k4::kType
k5::kType
k6::kType
k7::kType
k8::kType
ks::Vector{kType}
atmp::uNoUnitsType
nlsolver::N
tab::Tab
Expand All @@ -607,23 +530,9 @@ function alg_cache(alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits},
fsalfirst = zero(rate_prototype)

if f isa SplitFunction
k1 = zero(u)
k2 = zero(u)
k3 = zero(u)
k4 = zero(u)
k5 = zero(u)
k6 = zero(u)
k7 = zero(u)
k8 = zero(u)
ks = [zero(u) for _ in 1:8]
else
k1 = nothing
k2 = nothing
k3 = nothing
k4 = nothing
k5 = nothing
k6 = nothing
k7 = nothing
k8 = nothing
ks = [nothing for _ in 1:8]
end

z₁ = zero(u)
Expand All @@ -638,5 +547,5 @@ function alg_cache(alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)

KenCarp58Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈,
k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab)
ks, atmp, nlsolver, tab)
end
Loading
Loading