From 7f311e26ff91cdff22aa77698042b63dcf99f6f5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Aug 2024 16:06:39 +0530 Subject: [PATCH] fixup! feat: support parameter updates in `initialize_dae!` --- Project.toml | 2 ++ lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl | 5 ++++- src/OrdinaryDiffEq.jl | 2 ++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 24ff766f71..3f2053035c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "6.87.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -76,6 +77,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [compat] ADTypes = "0.2, 1" +Accessors = "0.1.36" Adapt = "3.0, 4" ArrayInterface = "7" DataStructures = "0.18" diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl index bd1c4b1249..2a2b0c5c9c 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl @@ -52,8 +52,11 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, else error("Unreachable reached. Report this error.") end - if isdefined(prob.f, :initializeprobpmap) && prob.f.initializeprobpmap !== nothing + if SciMLBase.has_initializeprobpmap(prob.f) integrator.p = prob.f.initializeprobpmap(prob, nlsol) + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol end if nlsol.retcode != ReturnCode.Success diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index b012ef46f8..085779b68d 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -63,6 +63,8 @@ import OrdinaryDiffEqCore: trivial_limiter!, CompositeAlgorithm, alg_order, export CompositeAlgorithm, ShampineCollocationInit, BrownFullBasicInit, NoInit AutoSwitch +using Accessors: @reset + import OrdinaryDiffEqDifferentiation using OrdinaryDiffEqDifferentiation: _alg_autodiff, resize_grad_config!, dolinsolve, wrapprecs, UJacobianWrapper, build_jac_config,