Skip to content

Commit

Permalink
Allow for multiple equations
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Oct 25, 2023
1 parent 0fca738 commit 0dd89e1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
8 changes: 4 additions & 4 deletions src/solvers/bicgstab_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ class BiCGSTABSolver {
INTERNALSOLVERVARIABLE(x, u);

BiCGSTABSolver(StateDescriptor *pkg, BiCGSTABParams params_in,
equations eq_in = equations())
: preconditioner(pkg, MGParams(), eq_in), params_(params_in), iter_counter(0),
equations eq_in = equations(), std::vector<int> shape = {})
: preconditioner(pkg, MGParams(), eq_in, shape), params_(params_in), iter_counter(0),
eqs_(eq_in) {
using namespace refinement_ops;
auto mu = Metadata({Metadata::Cell, Metadata::Independent, Metadata::FillGhost,
Metadata::WithFluxes, Metadata::GMGRestrict});
Metadata::WithFluxes, Metadata::GMGRestrict}, shape);
mu.RegisterRefinementOps<ProlongateSharedLinear, RestrictAverage>();
auto m_no_ghost = Metadata({Metadata::Cell, Metadata::Derived, Metadata::OneCopy});
auto m_no_ghost = Metadata({Metadata::Cell, Metadata::Derived, Metadata::OneCopy}, shape);
pkg->AddField(u::name(), mu);
pkg->AddField(rhat0::name(), m_no_ghost);
pkg->AddField(v::name(), m_no_ghost);
Expand Down
31 changes: 17 additions & 14 deletions src/solvers/mg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@ class MGSolver {
INTERNALSOLVERVARIABLE(u, u0); // Storage for initial solution during FAS
INTERNALSOLVERVARIABLE(u, D); // Storage for (approximate) diagonal

MGSolver(StateDescriptor *pkg, MGParams params_in, equations eq_in = equations())
MGSolver(StateDescriptor *pkg, MGParams params_in, equations eq_in = equations(), std::vector<int> shape = {})
: params_(params_in), iter_counter(0), eqs_(eq_in) {
using namespace parthenon::refinement_ops;
auto mres_err =
Metadata({Metadata::Cell, Metadata::Independent, Metadata::FillGhost,
Metadata::GMGRestrict, Metadata::GMGProlongate, Metadata::OneCopy});
Metadata::GMGRestrict, Metadata::GMGProlongate, Metadata::OneCopy}, shape);
mres_err.RegisterRefinementOps<ProlongateSharedLinear, RestrictAverage>();
pkg->AddField(res_err::name(), mres_err);

auto mtemp = Metadata({Metadata::Cell, Metadata::Independent, Metadata::FillGhost,
Metadata::WithFluxes, Metadata::OneCopy});
Metadata::WithFluxes, Metadata::OneCopy}, shape);
mtemp.RegisterRefinementOps<ProlongateSharedLinear, RestrictAverage>();
pkg->AddField(temp::name(), mtemp);

auto mu0 = Metadata({Metadata::Cell, Metadata::Derived, Metadata::OneCopy});
auto mu0 = Metadata({Metadata::Cell, Metadata::Derived, Metadata::OneCopy}, shape);
pkg->AddField(u0::name(), mu0);
pkg->AddField(D::name(), mu0);
}
Expand Down Expand Up @@ -165,16 +165,19 @@ class MGSolver {
const auto &coords = pack.GetCoordinates(b);
if ((i + j + k) % 2 == 1 && gs_type == GSType::red) return;
if ((i + j + k) % 2 == 0 && gs_type == GSType::black) return;

Real diag_elem = pack(b, te, D_t(), k, j, i);

// Get the off-diagonal contribution to Ax = (D + L + U)x = y
Real off_diag = pack(b, te, Axold_t(), k, j, i) -
diag_elem * pack(b, te, xold_t(), k, j, i);

Real val = pack(b, te, rhs_t(), k, j, i) - off_diag;
pack(b, te, xnew_t(), k, j, i) =
weight * val / diag_elem + (1.0 - weight) * pack(b, te, xold_t(), k, j, i);

const int nvars = pack.GetUpperBound(b, D_t()) - pack.GetLowerBound(b, D_t()) + 1;
for (int c = 0; c < nvars; ++c) {
Real diag_elem = pack(b, te, D_t(c), k, j, i);

// Get the off-diagonal contribution to Ax = (D + L + U)x = y
Real off_diag = pack(b, te, Axold_t(c), k, j, i) -
diag_elem * pack(b, te, xold_t(c), k, j, i);

Real val = pack(b, te, rhs_t(c), k, j, i) - off_diag;
pack(b, te, xnew_t(c), k, j, i) =
weight * val / diag_elem + (1.0 - weight) * pack(b, te, xold_t(c), k, j, i);
}
});
return TaskStatus::complete;
}
Expand Down
21 changes: 16 additions & 5 deletions src/solvers/solver_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ TaskStatus CopyData(const std::shared_ptr<MeshData<Real>> &md) {
DEFAULT_LOOP_PATTERN, "SetPotentialToZero", DevExecSpace(), 0,
pack.GetNBlocks() - 1, kb.s, kb.e, jb.s, jb.e, ib.s, ib.e,
KOKKOS_LAMBDA(const int b, const int k, const int j, const int i) {
pack(b, te, out(), k, j, i) = pack(b, te, in(), k, j, i);
const int nvars = pack.GetUpperBound(b, in()) - pack.GetLowerBound(b, in()) + 1;
for (int c = 0; c < nvars; ++c)
pack(b, te, out(c), k, j, i) = pack(b, te, in(c), k, j, i);
});
return TaskStatus::complete;
}
Expand Down Expand Up @@ -189,8 +191,11 @@ TaskStatus AddFieldsAndStoreInteriorSelect(const std::shared_ptr<MeshData<Real>>
DEFAULT_LOOP_PATTERN, "SetPotentialToZero", DevExecSpace(), 0,
pack.GetNBlocks() - 1, kb.s, kb.e, jb.s, jb.e, ib.s, ib.e,
KOKKOS_LAMBDA(const int b, const int k, const int j, const int i) {
pack(b, te, out(), k, j, i) =
wa * pack(b, te, a_t(), k, j, i) + wb * pack(b, te, b_t(), k, j, i);
const int nvars = pack.GetUpperBound(b, a_t()) - pack.GetLowerBound(b, a_t()) + 1;
for (int c = 0; c < nvars; ++c) {
pack(b, te, out(c), k, j, i) =
wa * pack(b, te, a_t(c), k, j, i) + wb * pack(b, te, b_t(c), k, j, i);
}
});
return TaskStatus::complete;
}
Expand Down Expand Up @@ -224,7 +229,11 @@ TaskStatus SetToZero(const std::shared_ptr<MeshData<Real>> &md) {
IndexRange kb = cb.GetBoundsK(IndexDomain::interior, te);
parthenon::par_for_inner(
parthenon::inner_loop_pattern_simdfor_tag, member, kb.s, kb.e, jb.s, jb.e,
ib.s, ib.e, [&](int k, int j, int i) { pack(b, te, var(), k, j, i) = 0.0; });
ib.s, ib.e, [&](int k, int j, int i) {
const int nvars = pack.GetUpperBound(b, var()) - pack.GetLowerBound(b, var()) + 1;
for (int c = 0; c < nvars; ++c)
pack(b, te, var(c), k, j, i) = 0.0;
});
});
return TaskStatus::complete;
}
Expand All @@ -245,7 +254,9 @@ TaskStatus DotProductLocal(const std::shared_ptr<MeshData<Real>> &md,
parthenon::loop_pattern_mdrange_tag, "DotProduct", DevExecSpace(), 0,
pack.GetNBlocks() - 1, kb.s, kb.e, jb.s, jb.e, ib.s, ib.e,
KOKKOS_LAMBDA(const int b, const int k, const int j, const int i, Real &lsum) {
lsum += pack(b, te, a_t(), k, j, i) * pack(b, te, b_t(), k, j, i);
const int nvars = pack.GetUpperBound(b, a_t()) - pack.GetLowerBound(b, a_t()) + 1;
for (int c = 0; c < nvars; ++c)
lsum += pack(b, te, a_t(c), k, j, i) * pack(b, te, b_t(c), k, j, i);
},
Kokkos::Sum<Real>(gsum));
adotb->val += gsum;
Expand Down

0 comments on commit 0dd89e1

Please sign in to comment.