Skip to content

Commit

Permalink
format and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Oct 24, 2023
1 parent afea55c commit 0fca738
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 28 deletions.
9 changes: 5 additions & 4 deletions example/poisson_gmg/poisson_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
for (int i = 0; i < num_partitions; ++i) {
TaskList &tl = region[i];
auto &md = pmesh->mesh_data.GetOrAdd("base", i);

// Possibly set rhs <- A.u_exact for a given u_exact so that the exact solution is
// known when we solve A.u = rhs
auto get_rhs = none;
Expand All @@ -94,10 +94,10 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
eqs.do_flux_cor = flux_correct;
get_rhs = eqs.Ax<u, rhs>(tl, comm, md);
}

// Set initial solution guess to zero
auto zero_u = tl.AddTask(get_rhs, solvers::utils::SetToZero<u>, md);

auto solve = zero_u;
auto &itl = tl.AddIteration("Solver");
if (solver == "BiCGSTAB") {
Expand All @@ -121,7 +121,8 @@ TaskCollection PoissonDriver::MakeTaskCollection(BlockList_t &blocks) {
if (partition != 0) return TaskStatus::complete;
driver->final_rms_error =
std::sqrt(driver->err.val / driver->pmesh->GetTotalCells());
if (Globals::my_rank == 0) printf("Final rms error: %e\n", driver->final_rms_error);
if (Globals::my_rank == 0)
printf("Final rms error: %e\n", driver->final_rms_error);
return TaskStatus::complete;
},
this, i);
Expand Down
15 changes: 10 additions & 5 deletions src/solvers/bicgstab_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ class BiCGSTABSolver {
reg_dep_id++;
if (i == 0) {
tl.AddTask(dependence, [&]() {
if (Globals::my_rank == 0) printf("# [0] v-cycle\n# [1] rms-residual\n# [2] rms-error\n");
if (Globals::my_rank == 0)
printf("# [0] v-cycle\n# [1] rms-residual\n# [2] rms-error\n");
return TaskStatus::complete;
});
}
Expand All @@ -114,7 +115,8 @@ class BiCGSTABSolver {
if (params_.precondition) {
auto set_rhs = itl.AddTask(precon1, CopyData<p, rhs>, md);
auto zero_u = itl.AddTask(precon1, SetToZero<u>, md);
precon1 = preconditioner.AddLinearOperatorTasks(region, itl, set_rhs | zero_u, i, reg_dep_id, pmesh);
precon1 = preconditioner.AddLinearOperatorTasks(region, itl, set_rhs | zero_u, i,
reg_dep_id, pmesh);
} else {
precon1 = itl.AddTask(initialize, CopyData<p, u>, md);
}
Expand Down Expand Up @@ -153,7 +155,8 @@ class BiCGSTABSolver {
[&](BiCGSTABSolver *solver, Mesh *pmesh, int partition) {
if (partition != 0) return TaskStatus::complete;
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0) printf("%i %e\n", solver->iter_counter * 2 + 1, rms_res);
if (Globals::my_rank == 0)
printf("%i %e\n", solver->iter_counter * 2 + 1, rms_res);
return TaskStatus::complete;
},
this, pmesh, i);
Expand All @@ -163,7 +166,8 @@ class BiCGSTABSolver {
if (params_.precondition) {
auto set_rhs = itl.AddTask(precon2, CopyData<s, rhs>, md);
auto zero_u = itl.AddTask(precon2, SetToZero<u>, md);
precon2 = preconditioner.AddLinearOperatorTasks(region, itl, set_rhs | zero_u, i, reg_dep_id, pmesh);
precon2 = preconditioner.AddLinearOperatorTasks(region, itl, set_rhs | zero_u, i,
reg_dep_id, pmesh);
} else {
precon2 = itl.AddTask(precon2, CopyData<s, u>, md);
}
Expand Down Expand Up @@ -203,7 +207,8 @@ class BiCGSTABSolver {
get_res2,
[&](BiCGSTABSolver *solver, Mesh *pmesh) {
Real rms_err = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
if (Globals::my_rank == 0) printf("%i %e\n", solver->iter_counter * 2 + 2, rms_err);
if (Globals::my_rank == 0)
printf("%i %e\n", solver->iter_counter * 2 + 2, rms_err);
return TaskStatus::complete;
},
this, pmesh);
Expand Down
39 changes: 21 additions & 18 deletions src/solvers/mg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ class MGSolver {
itl.AddTask(
dependence,
[](int partition, int *iter_counter) {
if (partition != 0 || *iter_counter > 0 || Globals::my_rank != 0) return TaskStatus::complete;
if (partition != 0 || *iter_counter > 0 || Globals::my_rank != 0)
return TaskStatus::complete;
printf("# [0] v-cycle\n# [1] rms-residual\n# [2] rms-error\n");
return TaskStatus::complete;
},
partition, &iter_counter);
auto mg_finest = AddLinearOperatorTasks(region, itl, dependence, partition, reg_dep_id, pmesh);
auto mg_finest =
AddLinearOperatorTasks(region, itl, dependence, partition, reg_dep_id, pmesh);
auto &md = pmesh->mesh_data.GetOrAdd("base", partition);
auto calc_pointwise_res = eqs_.template Ax<u, res_err>(itl, mg_finest, md);
calc_pointwise_res = itl.AddTask(
Expand Down Expand Up @@ -107,14 +109,16 @@ class MGSolver {
}

template <class TL_t>
TaskID AddLinearOperatorTasks(TaskRegion &region, TL_t &tl, TaskID dependence, int partition, int &reg_dep_id, Mesh *pmesh) {
TaskID AddLinearOperatorTasks(TaskRegion &region, TL_t &tl, TaskID dependence,
int partition, int &reg_dep_id, Mesh *pmesh) {
using namespace utils;
iter_counter = 0;

int min_level = 0;
int max_level = pmesh->GetGMGMaxLevel();

return AddMultiGridTasksPartitionLevel(region, tl, dependence, partition, reg_dep_id, max_level, min_level, max_level, pmesh);

return AddMultiGridTasksPartitionLevel(region, tl, dependence, partition, reg_dep_id,
max_level, min_level, max_level, pmesh);
}

Real GetSquaredResidualSum() const { return residual.val; }
Expand Down Expand Up @@ -199,7 +203,7 @@ class MGSolver {
{{0.8723, 0.5395, 0.0000}, {1.3895, 0.5617, 0.0000}, {1.7319, 0.5695, 0.0000}}};
std::array<std::array<Real, 3>, 3> omega_M3{
{{0.9372, 0.6667, 0.5173}, {1.6653, 0.8000, 0.5264}, {2.2473, 0.8571, 0.5296}}};

if (stages == 0) return depends_on;
auto omega = omega_M1;
if (stages == 2) omega = omega_M2;
Expand All @@ -221,17 +225,16 @@ class MGSolver {
}

template <class TL_t>
TaskID AddMultiGridTasksPartitionLevel(TaskRegion &region, TL_t &tl, TaskID dependence,
int partition, int &reg_dep_id,
int level, int min_level, int max_level,
Mesh *pmesh) {
TaskID AddMultiGridTasksPartitionLevel(TaskRegion &region, TL_t &tl, TaskID dependence,
int partition, int &reg_dep_id, int level,
int min_level, int max_level, Mesh *pmesh) {
using namespace utils;
auto smoother = params_.smoother;
bool do_FAS = params_.do_FAS;
int pre_stages, post_stages;
if (smoother == "none") {
if (smoother == "none") {
pre_stages = 0;
post_stages = 0;
post_stages = 0;
} else if (smoother == "SRJ1") {
pre_stages = 1;
post_stages = 1;
Expand All @@ -246,7 +249,6 @@ class MGSolver {
}

bool multilevel = (level != min_level);


auto &md = pmesh->gmg_mesh_data[level].GetOrAdd(level, "base", partition);

Expand Down Expand Up @@ -304,13 +306,14 @@ class MGSolver {
// 5. Restrict communication field and send to next level
auto communicate_to_coarse =
tl.AddTask(residual, SendBoundBufs<BoundaryType::gmg_restrict_send>, md);

auto coarser = AddMultiGridTasksPartitionLevel(region, tl, communicate_to_coarse, partition,
reg_dep_id, level - 1, min_level, max_level, pmesh);

auto coarser = AddMultiGridTasksPartitionLevel(region, tl, communicate_to_coarse,
partition, reg_dep_id, level - 1,
min_level, max_level, pmesh);

// 6. Receive error field into communication field and prolongate
auto recv_from_coarser = tl.AddTask(
coarser, ReceiveBoundBufs<BoundaryType::gmg_prolongate_recv>, md);
auto recv_from_coarser =
tl.AddTask(coarser, ReceiveBoundBufs<BoundaryType::gmg_prolongate_recv>, md);
auto set_from_coarser =
tl.AddTask(recv_from_coarser, SetBounds<BoundaryType::gmg_prolongate_recv>, md);
auto prolongate = tl.AddTask(
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/solver_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ TaskStatus AddFieldsAndStoreInteriorSelect(const std::shared_ptr<MeshData<Real>>
IndexRange ib = md->GetBoundsI(IndexDomain::entire, te);
IndexRange jb = md->GetBoundsJ(IndexDomain::entire, te);
IndexRange kb = md->GetBoundsK(IndexDomain::entire, te);

int nblocks = md->NumBlocks();
std::vector<bool> include_block(nblocks, true);
if (only_interior) {
Expand Down

0 comments on commit 0fca738

Please sign in to comment.