Skip to content

Commit

Permalink
Add hook for UserWorkBeforeLoop
Browse files Browse the repository at this point in the history
  • Loading branch information
jonahm-LANL committed Nov 6, 2023
1 parent 2857394 commit 06657a4
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 3 deletions.
7 changes: 7 additions & 0 deletions doc/sphinx/src/interface/state.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ several useful features and functions.
deletgates to the ``std::function`` member ``PostStepDiagnosticsMesh``
if set (defaults to ``nullptr`` an therefore a no-op) to print
diagnostics after the time-integration advance
- ``void UserWorkBeforeLoopMesh(Mesh *, ParameterInput *pin, SimTime
&tm)`` performs a per-package, mesh-wide calculation after the mesh
has been generated, and problem generators called, but before any
time evolution. This work is done both on first initialization and
on restart. If you would like to avoid doing the work upon restart,
you can check for the const ``is_restart`` member field of the ``Mesh``
object.

The reasoning for providing ``FillDerived*`` and ``EstimateTimestep*``
function pointers appropriate for usage with both ``MeshData`` and
Expand Down
10 changes: 10 additions & 0 deletions example/advection/advection_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <coordinates/coordinates.hpp>
#include <globals.hpp>
#include <parthenon/package.hpp>

#include "advection_package.hpp"
Expand Down Expand Up @@ -219,6 +221,14 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {
return pkg;
}

void AdvectionGreetings(Mesh *pmesh, ParameterInput *pin, SimTime &tm) {
if (GLobals::my_rank == 0) {
std::cout << "Hello from the advection package in the advection example!\n"
<< "This run is a restart: " << pmesh->is_restart
<< std::endl;
}
}

AmrTag CheckRefinement(MeshBlockData<Real> *rc) {
// refine on advected, for example. could also be a derived quantity
auto pmb = rc->GetBlockPointer();
Expand Down
1 change: 1 addition & 0 deletions example/advection/advection_package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace advection_package {
using namespace parthenon::package::prelude;

std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin);
void AdvectionGreetings(Mesh *pmes, ParameterInput *pin, SimTime &tm);
AmrTag CheckRefinement(MeshBlockData<Real> *rc);
void PreFill(MeshBlockData<Real> *rc);
void SquareIt(MeshBlockData<Real> *rc);
Expand Down
1 change: 1 addition & 0 deletions src/application_input.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct ApplicationInput {
PostStepDiagnosticsInLoop = nullptr;

std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkAfterLoop = nullptr;
std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkBeforeLoop = nullptr;
BValFunc boundary_conditions[BOUNDARY_NFACES] = {nullptr};
SBValFunc swarm_boundary_conditions[BOUNDARY_NFACES] = {nullptr};

Expand Down
10 changes: 10 additions & 0 deletions src/driver/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ DriverStatus EvolutionDriver::Execute() {
// Defaults must be set across all ranks
DumpInputParameters();

// Before loop do work
// App input version
if (app_input->UserWorkBeforeLoop != nullptr) {
app_input->UserWorkBeforeLoop(pmesh, pinput, tm);
}
// packages version
for (auto &[name, pkg] : pmesh->packages.AllPackages()) {
pkg->UserWorkBeforeLoop(pmesh, pinput, tm);
}

Kokkos::Profiling::pushRegion("Driver_Main");
while (tm.KeepGoing()) {
if (Globals::my_rank == 0) OutputCycleDiagnostics();
Expand Down
13 changes: 10 additions & 3 deletions src/interface/state_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,18 @@ class StateDescriptor {
// one can pass in a reference to a SparsePool or arguments that match one of the
// SparsePool constructors
template <typename... Args>
bool AddSparsePool(Args &&...args) {
bool AddSparsePool(Args &&... args) {
return AddSparsePoolImpl(SparsePool(std::forward<Args>(args)...));
}
template <typename... Args>
bool AddSparsePool(const std::string &base_name, const Metadata &m_in, Args &&...args) {
bool AddSparsePool(const std::string &base_name, const Metadata &m_in,
Args &&... args) {
Metadata m = m_in; // so we can modify it
if (!m.IsSet(GetMetadataFlag())) m.Set(GetMetadataFlag());
return AddSparsePoolImpl(SparsePool(base_name, m, std::forward<Args>(args)...));
}
template <typename T, typename... Args>
bool AddSparsePool(const Metadata &m_in, Args &&...args) {
bool AddSparsePool(const Metadata &m_in, Args &&... args) {
return AddSparsePool(T::name(), m_in, std::forward<Args>(args)...);
}

Expand Down Expand Up @@ -406,6 +407,10 @@ class StateDescriptor {
if (InitNewlyAllocatedVarsBlock != nullptr) return InitNewlyAllocatedVarsBlock(rc);
}

void UserWorkBeforeLoop(Mesh *pmesh, ParameterInput *pin, SimTime &tm) const {
if (UserWorkBeforeLoopMesh != nullptr) return UserWorkBeforeLoopMesh(pmesh, pin, tm);
}

std::vector<std::shared_ptr<AMRCriteria>> amr_criteria;

std::function<void(MeshBlockData<Real> *rc)> PreCommFillDerivedBlock = nullptr;
Expand All @@ -416,6 +421,8 @@ class StateDescriptor {
std::function<void(MeshData<Real> *rc)> PostFillDerivedMesh = nullptr;
std::function<void(MeshBlockData<Real> *rc)> FillDerivedBlock = nullptr;
std::function<void(MeshData<Real> *rc)> FillDerivedMesh = nullptr;
std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkBeforeLoopMesh =
nullptr;

std::function<void(SimTime const &simtime, MeshData<Real> *rc)> PreStepDiagnosticsMesh =
nullptr;
Expand Down
2 changes: 2 additions & 0 deletions src/mesh/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, Packages_t &packages,
int mesh_test)
: // public members:
modified(true),
is_restart(false),
// aggregate initialization of RegionSize struct:
mesh_size({pin->GetReal("parthenon/mesh", "x1min"),
pin->GetReal("parthenon/mesh", "x2min"),
Expand Down Expand Up @@ -485,6 +486,7 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, RestartReader &rr,
// aggregate initialization of RegionSize struct:
// (will be overwritten by memcpy from restart file, in this case)
modified(true),
is_restart(true),
// aggregate initialization of RegionSize struct:
mesh_size({pin->GetReal("parthenon/mesh", "x1min"),
pin->GetReal("parthenon/mesh", "x2min"),
Expand Down
1 change: 1 addition & 0 deletions src/mesh/mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Mesh {

// data
bool modified;
const bool is_restart;
RegionSize mesh_size;
BoundaryFlag mesh_bcs[BOUNDARY_NFACES];
const int ndim; // number of dimensions
Expand Down

0 comments on commit 06657a4

Please sign in to comment.