From 20646806d9b5484152c883711419d505677a9ae4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 6 Aug 2024 16:06:39 +0530 Subject: [PATCH] fixup! feat: support parameter updates in `initialize_dae!` --- Project.toml | 2 ++ src/OrdinaryDiffEq.jl | 2 ++ src/initialize_dae.jl | 5 ++++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f804b5652a..9f2514b507 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "6.87.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -46,6 +47,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [compat] ADTypes = "0.2, 1" +Accessors = "0.1.36" Adapt = "3.0, 4" ArrayInterface = "7" DataStructures = "0.18" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index c11fc6d414..895e0604a6 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -14,6 +14,8 @@ using Reexport using Logging +using Accessors: @reset + using MuladdMacro, SparseArrays, FastClosures using LinearAlgebra diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 96c9d932c4..1c9e3c4b74 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -164,8 +164,11 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, else error("Unreachable reached. Report this error.") end - if isdefined(prob.f, :initializeprobpmap) && prob.f.initializeprobpmap !== nothing + if SciMLBase.has_initializeprobpmap(prob.f) integrator.p = prob.f.initializeprobpmap(prob, nlsol) + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol end if nlsol.retcode != ReturnCode.Success