Skip to content

Commit

Permalink
finish cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Nov 21, 2024
1 parent 37500d2 commit 4789e0e
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 72 deletions.
2 changes: 1 addition & 1 deletion example/poisson_gmg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ if( "poisson-gmg-example" IN_LIST DRIVER_LIST OR NOT PARTHENON_DISABLE_EXAMPLES)
poisson-gmg-example
poisson_driver.cpp
poisson_driver.hpp
poisson_equation_stages.hpp
poisson_equation.hpp
poisson_package.cpp
poisson_package.hpp
main.cpp
Expand Down
25 changes: 13 additions & 12 deletions example/poisson_gmg/poisson_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
#include "mesh/meshblock_pack.hpp"
#include "parthenon/driver.hpp"
#include "poisson_driver.hpp"
#include "poisson_equation_stages.hpp"
#include "poisson_equation.hpp"
#include "poisson_package.hpp"
#include "prolong_restrict/prolong_restrict.hpp"
#include "solvers/bicgstab_solver_stages.hpp"
#include "solvers/cg_solver_stages.hpp"
#include "solvers/mg_solver_stages.hpp"
#include "solvers/bicgstab_solver.hpp"
#include "solvers/cg_solver.hpp"
#include "solvers/mg_solver.hpp"
#include "solvers/solver_utils.hpp"

using namespace parthenon::driver::prelude;
Expand Down Expand Up @@ -76,17 +76,18 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {

// Move the rhs variable into the rhs stage for stage based solver
auto copy_rhs = tl.AddTask(none, TF(solvers::utils::CopyData<rhs, u>), md);
copy_rhs = tl.AddTask(
copy_rhs, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md, md_rhs);
copy_rhs = tl.AddTask(copy_rhs, TF(solvers::utils::CopyData<parthenon::TypeList<u>>),
md, md_rhs);

// Possibly set rhs <- A.u_exact for a given u_exact so that the exact solution is
// known when we solve A.u = rhs
if (use_exact_rhs) {
auto copy_exact = tl.AddTask(copy_rhs, TF(solvers::utils::CopyData<exact, u>), md);
copy_exact = tl.AddTask(
copy_rhs, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md, md_u);
copy_rhs, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md, md_u);
auto comm = AddBoundaryExchangeTasks<BoundaryType::any>(copy_exact, tl, md_u, true);
auto *eqs = pkg->MutableParam<poisson_package::PoissonEquation<u, D>>("poisson_equation");
auto *eqs =
pkg->MutableParam<poisson_package::PoissonEquation<u, D>>("poisson_equation");
copy_rhs = eqs->Ax(tl, comm, md, md_u, md_rhs);
}

Expand All @@ -100,9 +101,9 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
// solution to the exact solution
if (use_exact_rhs) {
auto copy_back = tl.AddTask(
solve, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md_u, md);
auto diff = tl.AddTask(copy_back, TF(solvers::utils::AddFieldsAndStore<exact, u, u>),
md, 1.0, -1.0);
solve, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md_u, md);
auto diff = tl.AddTask(
copy_back, TF(solvers::utils::AddFieldsAndStore<exact, u, u>), md, 1.0, -1.0);
auto get_err = solvers::utils::DotProduct<u, u>(diff, tl, &err, md);
tl.AddTask(
get_err,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// license in this material to reproduce, prepare derivative works, distribute copies to
// the public, perform publicly and display publicly, and to permit others to do so.
//========================================================================================
#ifndef EXAMPLE_POISSON_GMG_POISSON_EQUATION_STAGES_HPP_
#define EXAMPLE_POISSON_GMG_POISSON_EQUATION_STAGES_HPP_
#ifndef EXAMPLE_POISSON_GMG_POISSON_EQUATION_HPP_
#define EXAMPLE_POISSON_GMG_POISSON_EQUATION_HPP_

#include <memory>
#include <set>
Expand Down Expand Up @@ -315,4 +315,4 @@ class PoissonEquation {

} // namespace poisson_package

#endif // EXAMPLE_POISSON_GMG_POISSON_EQUATION_STAGES_HPP_
#endif // EXAMPLE_POISSON_GMG_POISSON_EQUATION_HPP_
25 changes: 11 additions & 14 deletions example/poisson_gmg/poisson_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
#include <coordinates/coordinates.hpp>
#include <parthenon/driver.hpp>
#include <parthenon/package.hpp>
#include <solvers/bicgstab_solver_stages.hpp>
#include <solvers/cg_solver_stages.hpp>
#include <solvers/bicgstab_solver.hpp>
#include <solvers/cg_solver.hpp>
#include <solvers/solver_utils.hpp>

#include "defs.hpp"
#include "kokkos_abstraction.hpp"
#include "poisson_equation_stages.hpp"
#include "poisson_equation.hpp"
#include "poisson_package.hpp"

using namespace parthenon::package::prelude;
Expand Down Expand Up @@ -94,20 +94,17 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {

std::shared_ptr<parthenon::solvers::SolverBase> psolver;
using prolongator_t = parthenon::solvers::ProlongationBlockInteriorDefault;
using preconditioner_t =
parthenon::solvers::MGSolver<PoissEq, prolongator_t>;
using preconditioner_t = parthenon::solvers::MGSolver<PoissEq, prolongator_t>;
if (solver == "MG") {
psolver = std::make_shared<
parthenon::solvers::MGSolver<PoissEq, prolongator_t>>(
psolver = std::make_shared<parthenon::solvers::MGSolver<PoissEq, prolongator_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else if (solver == "CG") {
psolver = std::make_shared<
parthenon::solvers::CGSolver<PoissEq, preconditioner_t>>(
psolver = std::make_shared<parthenon::solvers::CGSolver<PoissEq, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else if (solver == "BiCGSTAB") {
psolver = std::make_shared<
parthenon::solvers::BiCGSTABSolver<PoissEq, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
psolver =
std::make_shared<parthenon::solvers::BiCGSTABSolver<PoissEq, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else {
PARTHENON_FAIL("Unknown solver type.");
}
Expand All @@ -122,8 +119,8 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {
// for the standard Poisson equation.
pkg->AddField(D::name(), mD);

std::vector<MetadataFlag> flags{Metadata::Cell, Metadata::Independent,
Metadata::FillGhost, Metadata::WithFluxes,
std::vector<MetadataFlag> flags{Metadata::Cell, Metadata::Independent,
Metadata::FillGhost, Metadata::WithFluxes,
Metadata::GMGRestrict, Metadata::GMGProlongate};
auto mflux_comm = Metadata(flags);
if (prolong == "Linear") {
Expand Down
6 changes: 3 additions & 3 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,10 @@ add_library(parthenon
amr_criteria/refinement_package.cpp
amr_criteria/refinement_package.hpp

solvers/bicgstab_solver_stages.hpp
solvers/cg_solver_stages.hpp
solvers/bicgstab_solver.hpp
solvers/cg_solver.hpp
solvers/internal_prolongation.hpp
solvers/mg_solver_stages.hpp
solvers/mg_solver.hpp
solvers/solver_base.hpp
solvers/solver_utils.hpp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// license in this material to reproduce, prepare derivative works, distribute copies to
// the public, perform publicly and display publicly, and to permit others to do so.
//========================================================================================
#ifndef SOLVERS_BICGSTAB_SOLVER_STAGES_HPP_
#define SOLVERS_BICGSTAB_SOLVER_STAGES_HPP_
#ifndef SOLVERS_BICGSTAB_SOLVER_HPP_
#define SOLVERS_BICGSTAB_SOLVER_HPP_

#include <cstdio>
#include <memory>
Expand All @@ -23,7 +23,7 @@
#include "interface/meshblock_data.hpp"
#include "interface/state_descriptor.hpp"
#include "kokkos_abstraction.hpp"
#include "solvers/mg_solver_stages.hpp"
#include "solvers/mg_solver.hpp"
#include "solvers/solver_base.hpp"
#include "solvers/solver_utils.hpp"
#include "tasks/tasks.hpp"
Expand Down Expand Up @@ -87,12 +87,13 @@ class BiCGSTABSolver : public SolverBase {
// Internal containers for solver which create deep copies of sol_fields
std::string container_rhat0, container_v, container_h, container_s;
std::string container_t, container_r, container_p, container_x, container_diag;

static inline std::size_t id{0};

public:
BiCGSTABSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations eq_in = equations())
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations eq_in = equations())
: preconditioner(container_base, container_u, container_rhs, pin, input_block,
eq_in),
container_base(container_base), container_u(container_u),
Expand Down Expand Up @@ -174,8 +175,8 @@ class BiCGSTABSolver : public SolverBase {
this);
tl.AddTask(
TaskQualifier::once_per_region, initialize, "print to screen",
[&](BiCGSTABSolver *solver, std::shared_ptr<Real> res_tol,
bool relative_residual, Mesh *pm) {
[&](BiCGSTABSolver *solver, std::shared_ptr<Real> res_tol, bool relative_residual,
Mesh *pm) {
if (Globals::my_rank == 0 && params_.print_per_step) {
Real tol = relative_residual
? *res_tol * std::sqrt(solver->rhs2.val / pm->GetTotalCells())
Expand Down Expand Up @@ -372,4 +373,4 @@ class BiCGSTABSolver : public SolverBase {

} // namespace parthenon

#endif // SOLVERS_BICGSTAB_SOLVER_STAGES_HPP_
#endif // SOLVERS_BICGSTAB_SOLVER_HPP_
23 changes: 12 additions & 11 deletions src/solvers/cg_solver_stages.hpp → src/solvers/cg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// license in this material to reproduce, prepare derivative works, distribute copies to
// the public, perform publicly and display publicly, and to permit others to do so.
//========================================================================================
#ifndef SOLVERS_CG_SOLVER_STAGES_HPP_
#define SOLVERS_CG_SOLVER_STAGES_HPP_
#ifndef SOLVERS_CG_SOLVER_HPP_
#define SOLVERS_CG_SOLVER_HPP_

#include <cstdio>
#include <limits>
Expand All @@ -24,7 +24,7 @@
#include "interface/meshblock_data.hpp"
#include "interface/state_descriptor.hpp"
#include "kokkos_abstraction.hpp"
#include "solvers/mg_solver_stages.hpp"
#include "solvers/mg_solver.hpp"
#include "solvers/solver_base.hpp"
#include "solvers/solver_utils.hpp"
#include "tasks/tasks.hpp"
Expand Down Expand Up @@ -79,18 +79,19 @@ class CGSolver : public SolverBase {
std::string container_x, container_r, container_v, container_p;

static inline std::size_t id{0};

public:
CGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, const equations &eq_in = equations())
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, const equations &eq_in = equations())
: preconditioner(container_base, container_u, container_rhs, pin, input_block,
eq_in),
container_base(container_base), container_u(container_u),
container_rhs(container_rhs), params_(pin, input_block), iter_counter(0),
eqs_(eq_in) {
FieldTL::IterateTypes(
[this](auto t) { this->sol_fields.push_back(decltype(t)::name()); });
std::string solver_id = "cg" + std::to_string(id++);
std::string solver_id = "cg" + std::to_string(id++);
container_x = solver_id + "_x";
container_r = solver_id + "_r";
container_v = solver_id + "_v";
Expand Down Expand Up @@ -146,8 +147,8 @@ class CGSolver : public SolverBase {
if (params_.print_per_step && Globals::my_rank == 0) {
initialize = tl.AddTask(
TaskQualifier::once_per_region, initialize, "print to screen",
[&](CGSolver *solver, std::shared_ptr<Real> res_tol,
bool relative_residual, Mesh *pm) {
[&](CGSolver *solver, std::shared_ptr<Real> res_tol, bool relative_residual,
Mesh *pm) {
Real tol = relative_residual
? *res_tol * std::sqrt(solver->rhs2.val / pm->GetTotalCells())
: *res_tol;
Expand Down Expand Up @@ -240,8 +241,8 @@ class CGSolver : public SolverBase {

auto check = itl.AddTask(
TaskQualifier::completion, get_res | correct_x, "completion",
[](CGSolver *solver, Mesh *pmesh, int max_iter,
std::shared_ptr<Real> res_tol, bool relative_residual) {
[](CGSolver *solver, Mesh *pmesh, int max_iter, std::shared_ptr<Real> res_tol,
bool relative_residual) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
solver->final_residual = rms_res;
solver->final_iteration = solver->iter_counter;
Expand Down Expand Up @@ -278,4 +279,4 @@ class CGSolver : public SolverBase {
} // namespace solvers
} // namespace parthenon

#endif // SOLVERS_CG_SOLVER_STAGES_HPP_
#endif // SOLVERS_CG_SOLVER_HPP_
25 changes: 12 additions & 13 deletions src/solvers/mg_solver_stages.hpp → src/solvers/mg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// license in this material to reproduce, prepare derivative works, distribute copies to
// the public, perform publicly and display publicly, and to permit others to do so.
//========================================================================================
#ifndef SOLVERS_MG_SOLVER_STAGES_HPP_
#define SOLVERS_MG_SOLVER_STAGES_HPP_
#ifndef SOLVERS_MG_SOLVER_HPP_
#define SOLVERS_MG_SOLVER_HPP_

#include <algorithm>
#include <cstdio>
Expand Down Expand Up @@ -77,6 +77,7 @@ struct MGParams {
template <class equations_t, class prolongator_t = ProlongationBlockInteriorDefault>
class MGSolver : public SolverBase {
static inline std::size_t id{0};

public:
using FieldTL = typename equations_t::IndependentVars;

Expand All @@ -95,16 +96,14 @@ class MGSolver : public SolverBase {
std::string container_res_err, container_temp, container_u0, container_diag;

MGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations_t eq_in = equations_t())
: MGSolver(container_base, container_u, container_rhs,
MGParams(pin, input_block), eq_in,
prolongator_t(pin, input_block)) {}
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations_t eq_in = equations_t())
: MGSolver(container_base, container_u, container_rhs, MGParams(pin, input_block),
eq_in, prolongator_t(pin, input_block)) {}

MGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, MGParams params_in,
equations_t eq_in = equations_t(),
prolongator_t prol_in = prolongator_t())
const std::string &container_rhs, MGParams params_in,
equations_t eq_in = equations_t(), prolongator_t prol_in = prolongator_t())
: container_base(container_base), container_u(container_u),
container_rhs(container_rhs), params_(params_in), iter_counter(0), eqs_(eq_in),
prolongator_(prol_in) {
Expand Down Expand Up @@ -292,8 +291,8 @@ class MGSolver : public SolverBase {
auto comm =
AddBoundaryExchangeTasks<comm_boundary>(depends_on, tl, md_in, multilevel);
auto mat_mult = eqs_.Ax(tl, comm, md_base, md_in, md_out);
return tl.AddTask(mat_mult, TF(&MGSolver::Jacobi), this, md_rhs, md_out,
md_diag, md_in, md_out, omega);
return tl.AddTask(mat_mult, TF(&MGSolver::Jacobi), this, md_rhs, md_out, md_diag,
md_in, md_out, omega);
}

template <parthenon::BoundaryType comm_boundary, class TL_t>
Expand Down Expand Up @@ -530,4 +529,4 @@ class MGSolver : public SolverBase {

} // namespace parthenon

#endif // SOLVERS_MG_SOLVER_STAGES_HPP_
#endif // SOLVERS_MG_SOLVER_HPP_
3 changes: 1 addition & 2 deletions src/solvers/solver_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ TaskStatus SetToZero(const std::shared_ptr<MeshData<Real>> &md) {
int nblocks = md->NumBlocks();
using TE = parthenon::TopologicalElement;
TE te = TE::CC;
static auto desc = [&]{
static auto desc = [&] {
if constexpr (isTypeList<TL>::value) {
return parthenon::MakePackDescriptorFromTypeList<TL>(md.get());
} else {
Expand Down Expand Up @@ -245,7 +245,6 @@ TaskStatus SetToZero(const std::shared_ptr<MeshData<Real>> &md) {
return TaskStatus::complete;
}


template <class a_t, class b_t, class out_t, bool only_fine_on_composite = true>
TaskStatus AddFieldsAndStoreInteriorSelect(const std::shared_ptr<MeshData<Real>> &md,
Real wa = 1.0, Real wb = 1.0,
Expand Down
6 changes: 2 additions & 4 deletions src/utils/type_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,10 @@ auto GetNames() {
}

template <class>
struct isTypeList : public std::false_type
{ };
struct isTypeList : public std::false_type {};

template <class... Ts>
struct isTypeList<TypeList<Ts...>> : public std::true_type
{ };
struct isTypeList<TypeList<Ts...>> : public std::true_type {};

} // namespace parthenon

Expand Down

0 comments on commit 4789e0e

Please sign in to comment.