Skip to content

Commit

Permalink
closer to resolving all stages
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Nov 21, 2024
1 parent ca1f1f5 commit 37500d2
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 49 deletions.
2 changes: 1 addition & 1 deletion doc/sphinx/src/solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ code along the lines of:
std::string rhs_cont_name = "rhs";

MySystemOfEquations eqs(....);
std::shared_ptr<SolverBase> psolver = std::make_shared<BiCGSTABSolverStages<MySystemOfEquations>>(
std::shared_ptr<SolverBase> psolver = std::make_shared<BiCGSTABSolver<MySystemOfEquations>>(
base_cont_name, u_cont_name, rhs_cont_name, pin, "location/of/solver_params", eqs);

...
Expand Down
10 changes: 7 additions & 3 deletions example/poisson_gmg/poisson_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
// known when we solve A.u = rhs
if (use_exact_rhs) {
auto copy_exact = tl.AddTask(copy_rhs, TF(solvers::utils::CopyData<exact, u>), md);
copy_exact = tl.AddTask(
copy_rhs, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md, md_u);
auto comm = AddBoundaryExchangeTasks<BoundaryType::any>(copy_exact, tl, md_u, true);
auto *eqs = pkg->MutableParam<poisson_package::PoissonEquationStages<u, D>>("poisson_equation");
copy_rhs = eqs->Ax(tl, comm, md, md, md_rhs);
auto *eqs = pkg->MutableParam<poisson_package::PoissonEquation<u, D>>("poisson_equation");
copy_rhs = eqs->Ax(tl, comm, md, md_u, md_rhs);
}

// Set initial solution guess to zero
Expand All @@ -97,7 +99,9 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
// If we are using a rhs to which we know the exact solution, compare our computed
// solution to the exact solution
if (use_exact_rhs) {
auto diff = tl.AddTask(solve, TF(solvers::utils::AddFieldsAndStore<exact, u, u>),
auto copy_back = tl.AddTask(
solve, TF(solvers::utils::CopyData<parthenon::TypeList<u>>), md_u, md);
auto diff = tl.AddTask(copy_back, TF(solvers::utils::AddFieldsAndStore<exact, u, u>),
md, 1.0, -1.0);
auto get_err = solvers::utils::DotProduct<u, u>(diff, tl, &err, md);
tl.AddTask(
Expand Down
4 changes: 2 additions & 2 deletions example/poisson_gmg/poisson_equation_stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ namespace poisson_package {
// are internal, but can't be marked private or protected because they launch kernels
// on device.
template <class var_t, class D_t>
class PoissonEquationStages {
class PoissonEquation {
public:
bool do_flux_cor = false;
bool set_flux_boundary = false;
bool include_flux_dx = false;

using IndependentVars = parthenon::TypeList<var_t>;

PoissonEquationStages(parthenon::ParameterInput *pin, const std::string &label) {
PoissonEquation(parthenon::ParameterInput *pin, const std::string &label) {
do_flux_cor = pin->GetOrAddBoolean(label, "flux_correct", false);
set_flux_boundary = pin->GetOrAddBoolean(label, "set_flux_boundary", false);
include_flux_dx =
Expand Down
24 changes: 12 additions & 12 deletions example/poisson_gmg/poisson_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,26 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {

std::string prolong = pin->GetOrAddString("poisson", "boundary_prolongation", "Linear");

using PoissEqStages = poisson_package::PoissonEquationStages<u, D>;
PoissEqStages eq(pin, "poisson");
using PoissEq = poisson_package::PoissonEquation<u, D>;
PoissEq eq(pin, "poisson");
pkg->AddParam<>("poisson_equation", eq, parthenon::Params::Mutability::Mutable);

std::shared_ptr<parthenon::solvers::SolverBase> psolver;
using prolongator_t = parthenon::solvers::ProlongationBlockInteriorDefault;
using preconditioner_t =
parthenon::solvers::MGSolverStages<PoissEqStages, prolongator_t>;
if (solver == "MGStages") {
parthenon::solvers::MGSolver<PoissEq, prolongator_t>;
if (solver == "MG") {
psolver = std::make_shared<
parthenon::solvers::MGSolverStages<PoissEqStages, prolongator_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEqStages(pin, "poisson"));
} else if (solver == "CGStages") {
parthenon::solvers::MGSolver<PoissEq, prolongator_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else if (solver == "CG") {
psolver = std::make_shared<
parthenon::solvers::CGSolverStages<PoissEqStages, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEqStages(pin, "poisson"));
} else if (solver == "BiCGSTABStages") {
parthenon::solvers::CGSolver<PoissEq, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else if (solver == "BiCGSTAB") {
psolver = std::make_shared<
parthenon::solvers::BiCGSTABSolverStages<PoissEqStages, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEqStages(pin, "poisson"));
parthenon::solvers::BiCGSTABSolver<PoissEq, preconditioner_t>>(
"base", "u", "rhs", pin, "poisson/solver_params", PoissEq(pin, "poisson"));
} else {
PARTHENON_FAIL("Unknown solver type.");
}
Expand Down
28 changes: 14 additions & 14 deletions src/solvers/bicgstab_solver_stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ struct BiCGSTABParams {
//
// that takes a field associated with x_t and applies
// the matrix A to it and stores the result in y_t.
template <class equations, class preconditioner_t = MGSolverStages<equations>>
class BiCGSTABSolverStages : public SolverBase {
template <class equations, class preconditioner_t = MGSolver<equations>>
class BiCGSTABSolver : public SolverBase {
using FieldTL = typename equations::IndependentVars;

std::vector<std::string> sol_fields;
Expand All @@ -90,7 +90,7 @@ class BiCGSTABSolverStages : public SolverBase {

static inline std::size_t id{0};
public:
BiCGSTABSolverStages(const std::string &container_base, const std::string &container_u,
BiCGSTABSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations eq_in = equations())
: preconditioner(container_base, container_u, container_rhs, pin, input_block,
Expand Down Expand Up @@ -167,14 +167,14 @@ class BiCGSTABSolverStages : public SolverBase {
TaskQualifier::once_per_region | TaskQualifier::local_sync,
zero_x | zero_u_init | copy_r | copy_p | copy_rhat0 | get_rhat0r_init | get_rhs2,
"zero factors",
[](BiCGSTABSolverStages *solver) {
[](BiCGSTABSolver *solver) {
solver->iter_counter = -1;
return TaskStatus::complete;
},
this);
tl.AddTask(
TaskQualifier::once_per_region, initialize, "print to screen",
[&](BiCGSTABSolverStages *solver, std::shared_ptr<Real> res_tol,
[&](BiCGSTABSolver *solver, std::shared_ptr<Real> res_tol,
bool relative_residual, Mesh *pm) {
if (Globals::my_rank == 0 && params_.print_per_step) {
Real tol = relative_residual
Expand All @@ -195,7 +195,7 @@ class BiCGSTABSolverStages : public SolverBase {
[]() { return TaskStatus::complete; });
auto reset = itl.AddTask(
TaskQualifier::once_per_region, sync, "update values",
[](BiCGSTABSolverStages *solver) {
[](BiCGSTABSolver *solver) {
solver->rhat0r_old = solver->rhat0r.val;
solver->iter_counter++;
return TaskStatus::complete;
Expand Down Expand Up @@ -226,7 +226,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 4. h <- x + alpha u (alpha = rhat0r_old / rhat0v)
auto correct_h = itl.AddTask(
get_rhat0v, "h <- x + alpha u",
[](BiCGSTABSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_x,
[](BiCGSTABSolver *solver, std::shared_ptr<MeshData<Real>> &md_x,
std::shared_ptr<MeshData<Real>> &md_u, std::shared_ptr<MeshData<Real>> &md_h) {
Real alpha = solver->rhat0r_old / solver->rhat0v.val;
return AddFieldsAndStore<FieldTL>(md_x, md_u, md_h, 1.0, alpha);
Expand All @@ -236,7 +236,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 5. s <- r - alpha v (alpha = rhat0r_old / rhat0v)
auto correct_s = itl.AddTask(
get_rhat0v, "s <- r - alpha v",
[](BiCGSTABSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_r,
[](BiCGSTABSolver *solver, std::shared_ptr<MeshData<Real>> &md_r,
std::shared_ptr<MeshData<Real>> &md_v, std::shared_ptr<MeshData<Real>> &md_s) {
Real alpha = solver->rhat0r_old / solver->rhat0v.val;
return AddFieldsAndStore<FieldTL>(md_r, md_v, md_s, 1.0, -alpha);
Expand All @@ -248,7 +248,7 @@ class BiCGSTABSolverStages : public SolverBase {

auto print = itl.AddTask(
TaskQualifier::once_per_region, get_res,
[&](BiCGSTABSolverStages *solver, Mesh *pmesh) {
[&](BiCGSTABSolver *solver, Mesh *pmesh) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0 && solver->params_.print_per_step)
printf("%i %e\n", solver->iter_counter * 2 + 1, rms_res);
Expand Down Expand Up @@ -281,7 +281,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 9. x <- h + omega u
auto correct_x = itl.AddTask(
get_tt | get_ts, "x <- h + omega u",
[](BiCGSTABSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_h,
[](BiCGSTABSolver *solver, std::shared_ptr<MeshData<Real>> &md_h,
std::shared_ptr<MeshData<Real>> &md_u, std::shared_ptr<MeshData<Real>> &md_x) {
Real omega = solver->ts.val / solver->tt.val;
return AddFieldsAndStore<FieldTL>(md_h, md_u, md_x, 1.0, omega);
Expand All @@ -291,7 +291,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 10. r <- s - omega t
auto correct_r = itl.AddTask(
get_tt | get_ts, "r <- s - omega t",
[](BiCGSTABSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_s,
[](BiCGSTABSolver *solver, std::shared_ptr<MeshData<Real>> &md_s,
std::shared_ptr<MeshData<Real>> &md_t, std::shared_ptr<MeshData<Real>> &md_r) {
Real omega = solver->ts.val / solver->tt.val;
return AddFieldsAndStore<FieldTL>(md_s, md_t, md_r, 1.0, -omega);
Expand All @@ -303,7 +303,7 @@ class BiCGSTABSolverStages : public SolverBase {

get_res2 = itl.AddTask(
TaskQualifier::once_per_region, get_res2,
[&](BiCGSTABSolverStages *solver, Mesh *pmesh) {
[&](BiCGSTABSolver *solver, Mesh *pmesh) {
Real rms_err = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0 && solver->params_.print_per_step)
printf("%i %e\n", solver->iter_counter * 2 + 2, rms_err);
Expand All @@ -318,7 +318,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 13. p <- r + beta * (p - omega * v)
auto update_p = itl.AddTask(
get_rhat0r | get_res2, "p <- r + beta * (p - omega * v)",
[](BiCGSTABSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_p,
[](BiCGSTABSolver *solver, std::shared_ptr<MeshData<Real>> &md_p,
std::shared_ptr<MeshData<Real>> &md_v, std::shared_ptr<MeshData<Real>> &md_r) {
Real alpha = solver->rhat0r_old / solver->rhat0v.val;
Real omega = solver->ts.val / solver->tt.val;
Expand All @@ -332,7 +332,7 @@ class BiCGSTABSolverStages : public SolverBase {
// 14. rhat0r_old <- rhat0r, zero all reductions
auto check = itl.AddTask(
TaskQualifier::completion, update_p | correct_x, "rhat0r_old <- rhat0r",
[partition](BiCGSTABSolverStages *solver, Mesh *pmesh, int max_iter,
[partition](BiCGSTABSolver *solver, Mesh *pmesh, int max_iter,
std::shared_ptr<Real> res_tol, bool relative_residual) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
solver->final_residual = rms_res;
Expand Down
22 changes: 11 additions & 11 deletions src/solvers/cg_solver_stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ struct CGParams {
//
// that takes a field associated with x_t and applies
// the matrix A to it and stores the result in y_t.
template <class equations, class preconditioner_t = MGSolverStages<equations>>
class CGSolverStages : public SolverBase {
template <class equations, class preconditioner_t = MGSolver<equations>>
class CGSolver : public SolverBase {
using FieldTL = typename equations::IndependentVars;

std::vector<std::string> sol_fields;
Expand All @@ -80,7 +80,7 @@ class CGSolverStages : public SolverBase {

static inline std::size_t id{0};
public:
CGSolverStages(const std::string &container_base, const std::string &container_u,
CGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, const equations &eq_in = equations())
: preconditioner(container_base, container_u, container_rhs, pin, input_block,
Expand Down Expand Up @@ -136,7 +136,7 @@ class CGSolverStages : public SolverBase {
auto initialize = tl.AddTask(
TaskQualifier::once_per_region | TaskQualifier::local_sync,
zero_u | zero_v | zero_x | zero_p | copy_r | get_rhs2, "zero factors",
[](CGSolverStages *solver) {
[](CGSolver *solver) {
solver->iter_counter = -1;
solver->ru.val = std::numeric_limits<Real>::max();
return TaskStatus::complete;
Expand All @@ -146,7 +146,7 @@ class CGSolverStages : public SolverBase {
if (params_.print_per_step && Globals::my_rank == 0) {
initialize = tl.AddTask(
TaskQualifier::once_per_region, initialize, "print to screen",
[&](CGSolverStages *solver, std::shared_ptr<Real> res_tol,
[&](CGSolver *solver, std::shared_ptr<Real> res_tol,
bool relative_residual, Mesh *pm) {
Real tol = relative_residual
? *res_tol * std::sqrt(solver->rhs2.val / pm->GetTotalCells())
Expand All @@ -166,7 +166,7 @@ class CGSolverStages : public SolverBase {
[]() { return TaskStatus::complete; });
auto reset = itl.AddTask(
TaskQualifier::once_per_region, sync, "update values",
[](CGSolverStages *solver) {
[](CGSolver *solver) {
solver->ru_old = solver->ru.val;
solver->iter_counter++;
return TaskStatus::complete;
Expand All @@ -190,7 +190,7 @@ class CGSolverStages : public SolverBase {
// 3. p <- u + beta p
auto correct_p = itl.AddTask(
get_ru, "p <- u + beta p",
[](CGSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_u,
[](CGSolver *solver, std::shared_ptr<MeshData<Real>> &md_u,
std::shared_ptr<MeshData<Real>> &md_p) {
Real beta = solver->iter_counter > 0 ? solver->ru.val / solver->ru_old : 0.0;
return AddFieldsAndStore<FieldTL>(md_u, md_p, md_p, 1.0, beta);
Expand All @@ -208,7 +208,7 @@ class CGSolverStages : public SolverBase {
// 6. x <- x + alpha p
auto correct_x = itl.AddTask(
get_pAp, "x <- x + alpha p",
[](CGSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_x,
[](CGSolver *solver, std::shared_ptr<MeshData<Real>> &md_x,
std::shared_ptr<MeshData<Real>> &md_p) {
Real alpha = solver->ru.val / solver->pAp.val;
return AddFieldsAndStore<FieldTL>(md_x, md_p, md_x, 1.0, alpha);
Expand All @@ -218,7 +218,7 @@ class CGSolverStages : public SolverBase {
// 6. r <- r - alpha A p
auto correct_r = itl.AddTask(
get_pAp, "r <- r - alpha A p",
[](CGSolverStages *solver, std::shared_ptr<MeshData<Real>> &md_r,
[](CGSolver *solver, std::shared_ptr<MeshData<Real>> &md_r,
std::shared_ptr<MeshData<Real>> &md_v) {
Real alpha = solver->ru.val / solver->pAp.val;
return AddFieldsAndStore<FieldTL>(md_r, md_v, md_r, 1.0, -alpha);
Expand All @@ -230,7 +230,7 @@ class CGSolverStages : public SolverBase {

auto print = itl.AddTask(
TaskQualifier::once_per_region, get_res,
[&](CGSolverStages *solver, Mesh *pmesh) {
[&](CGSolver *solver, Mesh *pmesh) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0 && solver->params_.print_per_step)
printf("%i %e\n", solver->iter_counter, rms_res);
Expand All @@ -240,7 +240,7 @@ class CGSolverStages : public SolverBase {

auto check = itl.AddTask(
TaskQualifier::completion, get_res | correct_x, "completion",
[](CGSolverStages *solver, Mesh *pmesh, int max_iter,
[](CGSolver *solver, Mesh *pmesh, int max_iter,
std::shared_ptr<Real> res_tol, bool relative_residual) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
solver->final_residual = rms_res;
Expand Down
12 changes: 6 additions & 6 deletions src/solvers/mg_solver_stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct MGParams {
// That stores the (possibly approximate) diagonal of matrix A in the field
// associated with the type diag_t. This is used for Jacobi iteration.
template <class equations_t, class prolongator_t = ProlongationBlockInteriorDefault>
class MGSolverStages : public SolverBase {
class MGSolver : public SolverBase {
static inline std::size_t id{0};
public:
using FieldTL = typename equations_t::IndependentVars;
Expand All @@ -94,14 +94,14 @@ class MGSolverStages : public SolverBase {
// Internal containers for solver which create deep copies of sol_fields
std::string container_res_err, container_temp, container_u0, container_diag;

MGSolverStages(const std::string &container_base, const std::string &container_u,
MGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, ParameterInput *pin,
const std::string &input_block, equations_t eq_in = equations_t())
: MGSolverStages(container_base, container_u, container_rhs,
: MGSolver(container_base, container_u, container_rhs,
MGParams(pin, input_block), eq_in,
prolongator_t(pin, input_block)) {}

MGSolverStages(const std::string &container_base, const std::string &container_u,
MGSolver(const std::string &container_base, const std::string &container_u,
const std::string &container_rhs, MGParams params_in,
equations_t eq_in = equations_t(),
prolongator_t prol_in = prolongator_t())
Expand Down Expand Up @@ -151,7 +151,7 @@ class MGSolverStages : public SolverBase {

auto check = itl.AddTask(
TaskQualifier::completion, get_res, "Check residual",
[partition](MGSolverStages *solver, Mesh *pmesh) {
[partition](MGSolver *solver, Mesh *pmesh) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0 && partition == 0)
printf("%i %e\n", solver->iter_counter, rms_res);
Expand Down Expand Up @@ -292,7 +292,7 @@ class MGSolverStages : public SolverBase {
auto comm =
AddBoundaryExchangeTasks<comm_boundary>(depends_on, tl, md_in, multilevel);
auto mat_mult = eqs_.Ax(tl, comm, md_base, md_in, md_out);
return tl.AddTask(mat_mult, TF(&MGSolverStages::Jacobi), this, md_rhs, md_out,
return tl.AddTask(mat_mult, TF(&MGSolver::Jacobi), this, md_rhs, md_out,
md_diag, md_in, md_out, omega);
}

Expand Down

0 comments on commit 37500d2

Please sign in to comment.