Skip to content

Commit

Permalink
Sundials Update (AMReX-Codes#3984)
Browse files Browse the repository at this point in the history
## Summary
Changes to the time integrator interface to now support explicit,
implicit, and ImEx methods with fixed or adaptive time step size, as
well as MRI approaches.

## Additional background
Can be tested with amrex-tutorials PR
AMReX-Codes/amrex-tutorials#123

---------

Co-authored-by: David J. Gardner <[email protected]>
  • Loading branch information
ajnonaka and gardner48 authored Jun 21, 2024
1 parent adfc227 commit 259db7c
Show file tree
Hide file tree
Showing 5 changed files with 837 additions and 793 deletions.
70 changes: 56 additions & 14 deletions Src/Base/AMReX_FEIntegrator.H
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,92 @@ private:

amrex::Vector<std::unique_ptr<T> > F_nodes;

void initialize_stages (const T& S_data)
// Current (internal) state and time
amrex::Vector<std::unique_ptr<T> > S_current;
amrex::Real time_current;

void initialize_stages (const T& S_data, const amrex::Real time)
{
// Create data for stage RHS
IntegratorOps<T>::CreateLike(F_nodes, S_data);

// Create and initialize data for current state
IntegratorOps<T>::CreateLike(S_current, S_data, true);
IntegratorOps<T>::Copy(*S_current[0], S_data);

// Set the initial time
time_current = time;
}

public:
FEIntegrator () {}

FEIntegrator (const T& S_data)
FEIntegrator (const T& S_data, const amrex::Real time = 0.0)
{
initialize(S_data);
initialize(S_data, time);
}

virtual ~FEIntegrator () {}

void initialize (const T& S_data) override
void initialize (const T& S_data, const amrex::Real time = 0.0)
{
initialize_stages(S_data);
initialize_stages(S_data, time);
}

amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override
amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override
{
BaseT::timestep = time_step;
// Assume before advance() that S_old is valid data at the current time ("time" argument)
// Assume before step() that S_old is valid data at the current time ("time" argument)
// So we initialize S_new by copying the old state.
IntegratorOps<T>::Copy(S_new, S_old);

// Call the pre RHS hook
BaseT::pre_rhs_action(S_new, time);

// F = RHS(S, t)
T& F = *F_nodes[0];
BaseT::rhs(F, S_new, time);
BaseT::Rhs(F, S_new, time);

// S_new += timestep * dS/dt
IntegratorOps<T>::Saxpy(S_new, BaseT::timestep, F);
IntegratorOps<T>::Saxpy(S_new, dt, F);

// Call the post-update hook for S_new
BaseT::post_update(S_new, time + BaseT::timestep);
// Call the post step hook
BaseT::post_step_action(S_new, time + dt);

// Return timestep
return BaseT::timestep;
return dt;
}

void evolve (T& S_out, const amrex::Real time_out) override
{
amrex::Real dt = BaseT::time_step;
bool stop = false;

for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number)
{
// Adjust step size to reach output time
if (time_out - time_current < dt) {
dt = time_out - time_current;
stop = true;
}

// Call the time integrator step
advance(*S_current[0], S_out, time_current, dt);

// Update current state S_current = S_out
IntegratorOps<T>::Copy(*S_current[0], S_out);

// Update time
time_current += dt;

if (step_number == BaseT::max_steps - 1) {
Error("Did not reach output time in max steps.");
}
}
}

virtual void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override
{
amrex::Error("Time interpolation not yet supported by forward euler integrator.");
amrex::Error("Time interpolation not yet supported by the forward euler integrator.");
}

virtual void map_data (std::function<void(T&)> Map) override
Expand Down
238 changes: 181 additions & 57 deletions Src/Base/AMReX_IntegratorBase.H
Original file line number Diff line number Diff line change
Expand Up @@ -161,37 +161,120 @@ struct IntegratorOps<T, std::enable_if_t<std::is_same_v<amrex::MultiFab, T> > >
template<class T>
class IntegratorBase
{
private:
/**
* \brief Fun is the right-hand-side function the integrator will use.
*/
std::function<void(T&, const T&, const amrex::Real)> Fun;

/**
* \brief FastFun is the fast timescale right-hand-side function for a multirate integration problem.
*/
std::function<void(T&, T&, const T&, const amrex::Real)> FastFun;

protected:
/**
* \brief Integrator timestep size (Real)
*/
amrex::Real timestep;

/**
* \brief For multirate problems, the ratio of slow timestep size / fast timestep size (int)
*/
int slow_fast_timestep_ratio = 0;

/**
* \brief For multirate problems, the fast timestep size (Real)
*/
Real fast_timestep = 0.0;

/**
* \brief The post_update function is called by the integrator on state data before using it to evaluate a right-hand side.
*/
std::function<void (T&, amrex::Real)> post_update;
/**
* \brief Rhs is the right-hand-side function the integrator will use.
*/
std::function<void(T& rhs, const T& state, const amrex::Real time)> Rhs;

/**
* \brief RhsIm is the implicit right-hand-side function an ImEx integrator
* will use.
*/
std::function<void(T& rhs, const T& state, const amrex::Real time)> RhsIm;

/**
* \brief RhsEx is the explicit right-hand-side function an ImEx integrator
* will use.
*/
std::function<void(T& rhs, const T& state, const amrex::Real time)> RhsEx;

/**
* \brief RhsFast is the fast timescale right-hand-side function a multirate
* integrator will use.
*/
std::function<void(T& rhs, const T& state, const amrex::Real time)> RhsFast;

/**
* \brief The pre_rhs_action function is called by the integrator on state
* data before using it to evaluate a right-hand side.
*/
std::function<void (T&, amrex::Real)> pre_rhs_action;

/**
* \brief The post_stage_action function is called by the integrator on
* the computed stage just after it is computed
*/
std::function<void (T&, amrex::Real)> post_stage_action;

/**
* \brief The post_step_action function is called by the integrator on
* the computed state just after it is computed
*/
std::function<void (T&, amrex::Real)> post_step_action;

/**
* \brief The post_stage_action function is called by the integrator on
* the computed stage just after it is computed
*/
std::function<void (T&, amrex::Real)> post_fast_stage_action;

/**
* \brief The post_step_action function is called by the integrator on
* the computed state just after it is computed
*/
std::function<void (T&, amrex::Real)> post_fast_step_action;

/**
* \brief Flag to enable/disable adaptive time stepping in single rate
* methods or at the slow time scale in multirate methods (bool)
*/
bool use_adaptive_time_step = false;

/**
* \brief Current integrator time step size (Real)
*/
amrex::Real time_step;

/**
* \brief Step size of the last completed step (Real)
*/
amrex::Real previous_time_step;

/**
* \brief Flag to enable/disable adaptive time stepping at the fast time
* scale in multirate methods (bool)
*/
bool use_adaptive_fast_time_step = false;

/**
* \brief Current integrator fast time scale time step size with multirate
* methods (Real)
*/
amrex::Real fast_time_step;

/**
* \brief Number of integrator time steps (Long)
*/
amrex::Long num_steps = 0;

/**
* \brief Max number of internal steps before an error is returned (Long)
*/
int max_steps = 500;

/**
* \brief Relative tolerance for adaptive time stepping (Real)
*/
amrex::Real rel_tol = 1.0e-4;

/**
* \brief Absolute tolerance for adaptive time stepping (Real)
*/
amrex::Real abs_tol = 1.0e-9;

/**
* \brief Relative tolerance for adaptive time stepping at the fast time
* scale (Real)
*/
amrex::Real fast_rel_tol = 1.0e-4;

/**
* \brief Absolute tolerance for adaptive time stepping at the fast time
* scale (Real)
*/
amrex::Real fast_abs_tol = 1.0e-9;


public:
IntegratorBase () = default;
Expand All @@ -200,71 +283,112 @@ public:

virtual ~IntegratorBase () = default;

virtual void initialize (const T& S_data) = 0;

void set_rhs (std::function<void(T&, const T&, const amrex::Real)> F)
{
Fun = F;
Rhs = F;
}

void set_imex_rhs (std::function<void(T&, const T&, const amrex::Real)> Fi,
std::function<void(T&, const T&, const amrex::Real)> Fe)
{
RhsIm = Fi;
RhsEx = Fe;
}

void set_fast_rhs (std::function<void(T&, const T&, const amrex::Real)> F)
{
RhsFast = F;
}

void set_fast_rhs (std::function<void(T&, T&, const T&, const amrex::Real)> F)
void set_pre_rhs_action (std::function<void (T&, amrex::Real)> A)
{
FastFun = F;
pre_rhs_action = A;
}

void set_slow_fast_timestep_ratio (const int timestep_ratio = 1)
void set_post_stage_action (std::function<void (T&, amrex::Real)> A)
{
slow_fast_timestep_ratio = timestep_ratio;
post_stage_action = A;
}

void set_fast_timestep (const Real fast_dt = 1.0)
void set_post_step_action (std::function<void (T&, amrex::Real)> A)
{
fast_timestep = fast_dt;
post_step_action = A;
}

void set_post_update (std::function<void (T&, amrex::Real)> F)
void set_post_fast_stage_action (std::function<void (T&, amrex::Real)> A)
{
post_update = F;
post_fast_stage_action = A;
}

std::function<void (T&, amrex::Real)> get_post_update ()
void set_post_fast_step_action (std::function<void (T&, amrex::Real)> A)
{
return post_update;
post_fast_step_action = A;
}

std::function<void(T&, const T&, const amrex::Real)> get_rhs ()
void set_post_update (std::function<void (T&, amrex::Real)> A)
{
return Fun;
set_post_stage_action(A);
set_post_step_action(A);
}

std::function<void(T&, T&, const T&, const amrex::Real)> get_fast_rhs ()
amrex::Real get_time_step ()
{
return FastFun;
return time_step;
}

int get_slow_fast_timestep_ratio ()
void set_time_step (amrex::Real dt)
{
return slow_fast_timestep_ratio;
time_step = dt;
use_adaptive_time_step = false;
}

Real get_fast_timestep ()
void set_adaptive_step ()
{
return fast_timestep;
use_adaptive_time_step = true;
}

void rhs (T& S_rhs, const T& S_data, const amrex::Real time)
void set_fast_time_step (amrex::Real dt)
{
Fun(S_rhs, S_data, time);
fast_time_step = dt;
use_adaptive_fast_time_step = false;
}

void fast_rhs (T& S_rhs, T& S_extra, const T& S_data, const amrex::Real time)
void set_adaptive_fast_step ()
{
FastFun(S_rhs, S_extra, S_data, time);
use_adaptive_fast_time_step = true;
}

virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0;
void set_max_steps (int steps)
{
max_steps = steps;
}

void set_tolerances (amrex::Real rtol, amrex::Real atol)
{
rel_tol = rtol;
abs_tol = atol;
}

void set_fast_tolerances (amrex::Real rtol, amrex::Real atol)
{
fast_rel_tol = rtol;
fast_abs_tol = atol;
}

/**
* \brief Take a single time step from (time, S_old) to (time + dt, S_new)
* with the given step size.
*/
virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time,
amrex::Real dt) = 0;

/**
* \brief Evolve the current (internal) integrator state to time_out
*/
virtual void evolve (T& S_out, const amrex::Real time_out) = 0;

virtual void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) = 0;
virtual void time_interpolate (const T& S_new, const T& S_old,
amrex::Real timestep_fraction, T& data) = 0;

virtual void map_data (std::function<void(T&)> Map) = 0;
};
Expand Down
Loading

0 comments on commit 259db7c

Please sign in to comment.