Skip to content

Commit

Permalink
Allow derivimplicit to use finite differences.
Browse files Browse the repository at this point in the history
In NOCMODL `derivimplicit` uses finite differences to compute each
element of the Jacobian. In stead we try to use SymPy, however, if it
fails, e.g. because it encounters an opaque function, we allow it to use
a finite difference instead.
  • Loading branch information
1uc committed Sep 20, 2024
1 parent 3584a03 commit 29e6b6c
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 30 deletions.
68 changes: 60 additions & 8 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,55 @@ def solve_lin_system(
return code, new_local_vars


def finite_difference_step_variable(sym):
return f"{sym}_delta_"


def discretize_derivative(expr):
if isinstance(expr, sp.Derivative):
x = expr.args[1][0]
dx = sp.symbols(finite_difference_step_variable(x))
return expr.as_finite_difference(dx)
else:
return expr


def transform_expression(expr, transform):
if expr.args is tuple():
return expr

args = list(transform_expression(transform(arg), transform) for arg in expr.args)
return expr.func(*args)


def transform_matrix_elements(mat, transform):
return sp.Matrix(
[
[transform_expression(mat[i, j], transform) for j in range(mat.rows)]
for i in range(mat.cols)
]
)


def finite_difference_variables(mat):
vars = []

def recurse(expr):
for arg in expr.args:
if isinstance(arg, sp.Derivative):
var = arg.args[1][0]
vars.append((var, finite_difference_step_variable(var)))

for expr in mat:
recurse(expr)

return vars


def needs_finite_differences(mat):
return any(isinstance(expr, sp.Derivative) for expr in sp.preorder_traversal(mat))


def solve_non_lin_system(eq_strings, vars, constants, function_calls):
"""Solve non-linear system of equations, return solution as C code.
Expand All @@ -369,28 +418,31 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
if needs_finite_differences(jacobian):
jacobian = transform_matrix_elements(jacobian, discretize_derivative)

X_vec_map = {x: sp.symbols(f"X[{i}]") for i, x in enumerate(state_vars)}
dX_vec_map = {
finite_difference_step_variable(x): sp.symbols(f"dX_[{i}]")
for i, x in enumerate(state_vars)
}

vecFcode = []
for i, eq in enumerate(eqs):
vecFcode.append(
f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}"
)
expr = eq.simplify().subs(X_vec_map).evalf()
rhs = sp.ccode(expr, user_functions=custom_fcts)
vecFcode.append(f"F[{i}] = {rhs}")

vecJcode = []
for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)):
flat_index = i + jacobian.rows * j

rhs = sp.ccode(
jacobian[i, j].simplify().subs(X_vec_map).evalf(),
user_functions=custom_fcts,
)
Jij = jacobian[i, j].simplify().subs({**X_vec_map, **dX_vec_map}).evalf()
rhs = sp.ccode(Jij, user_functions=custom_fcts)
vecJcode.append(f"J[{flat_index}] = {rhs}")

# interweave
code = _interweave_eqs(vecFcode, vecJcode)

code = search_and_replace_protected_identifiers_from_sympy(code, function_calls)

return code
Expand Down
8 changes: 8 additions & 0 deletions src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,15 +708,23 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo

printer->fmt_text(
"void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, "
"1>& nmodl_eigen_dxm, Eigen::Matrix<{0}, {1}, "
"1>& nmodl_eigen_fm, "
"Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}",
float_type,
N,
is_functor_const(variable_block, functor_block) ? "const " : "");
printer->push_block();
printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_dx = nmodl_eigen_dxm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);

for (size_t i = 0; i < N; ++i) {
printer->fmt_line(
"nmodl_eigen_dx[{0}] = std::max(1e-6, 0.02*std::fabs(nmodl_eigen_x[{0}]));", i);
}

print_statement_block(functor_block, false, false);
printer->pop_block();
printer->add_newline();
Expand Down
7 changes: 3 additions & 4 deletions src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@ std::tuple<std::vector<std::string>, std::string> call_solve_nonlinear_system(
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
import traceback
solutions = [""]
new_local_vars = [""]
exception_message = traceback.format_exc()
)";

Expand Down
7 changes: 5 additions & 2 deletions src/solver/newton/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix<double, N, 1>& X,
FUNC functor,
double eps = EPS,
int max_iter = MAX_ITER) {
// If finite differences are needed, this is stores the stepwidth.
Eigen::Matrix<double, N, 1> dX;
// Vector to store result of function F(X):
Eigen::Matrix<double, N, 1> F;
// Matrix to store Jacobian of F(X):
Expand All @@ -89,7 +91,7 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix<double, N, 1>& X,
int iter = -1;
while (++iter < max_iter) {
// calculate F, J from X using user-supplied functor
functor(X, F, J);
functor(X, dX, F, J);
if (is_converged(X, J, F, eps)) {
return iter;
}
Expand Down Expand Up @@ -127,10 +129,11 @@ EIGEN_DEVICE_FUNC int newton_solver_small_N(Eigen::Matrix<double, N, 1>& X,
int max_iter) {
bool invertible;
Eigen::Matrix<double, N, 1> F;
Eigen::Matrix<double, N, 1> dX;
Eigen::Matrix<double, N, N> J, J_inv;
int iter = -1;
while (++iter < max_iter) {
functor(X, F, J);
functor(X, dX, F, J);
if (is_converged(X, J, F, eps)) {
return iter;
}
Expand Down
19 changes: 11 additions & 8 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ void SympySolverVisitor::construct_eigen_solver_block(
const std::vector<std::string>& solutions,
bool linear) {
auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x[");
solutions_filtered = filter_string_vector(solutions_filtered, "dX_[", "nmodl_eigen_dx[");
solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j[");
solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm[");
solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f[");
Expand All @@ -187,16 +188,19 @@ void SympySolverVisitor::construct_eigen_solver_block(
logger->debug("SympySolverVisitor :: -> adding statement: {}", sol);
}

std::vector<std::string> pre_solve_statements_and_setup_x_eqs(pre_solve_statements);
std::vector<std::string> pre_solve_statements_and_setup_x_eqs = pre_solve_statements;
std::vector<std::string> update_statements;

for (int i = 0; i < state_vars.size(); i++) {
auto update_state = state_vars[i] + " = nmodl_eigen_x[" + std::to_string(i) + "]";
auto setup_x = "nmodl_eigen_x[" + std::to_string(i) + "] = " + state_vars[i];
auto eigen_name = fmt::format("nmodl_eigen_x[{}]", i);

pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
auto update_state = fmt::format("{} = {}", state_vars[i], eigen_name);
update_statements.push_back(update_state);
logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
logger->debug("SympySolverVisitor :: update_state: {}", update_state);

auto setup_x = fmt::format("{} = {}", eigen_name, state_vars[i]);
pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
}

visitor::SympyReplaceSolutionsVisitor solution_replacer(
Expand Down Expand Up @@ -304,9 +308,7 @@ void SympySolverVisitor::construct_eigen_solver_block(


void SympySolverVisitor::solve_linear_system(const ast::Node& node,
const std::vector<std::string>& pre_solve_statements

) {
const std::vector<std::string>& pre_solve_statements) {
// construct ordered vector of state vars used in linear system
init_state_vars_vector(&node);
// call sympy linear solver
Expand Down Expand Up @@ -373,6 +375,7 @@ void SympySolverVisitor::solve_non_linear_system(
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");

construct_eigen_solver_block(pre_solve_statements, solutions, false);
}

Expand Down
30 changes: 22 additions & 8 deletions test/unit/newton/newton.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("1 linear eq") {
struct functor {
void operator()(const Eigen::Matrix<double, 1, 1>& X,
Eigen::Matrix<double, 1, 1>& /* dX */,
Eigen::Matrix<double, 1, 1>& F,
Eigen::Matrix<double, 1, 1>& J) const {
// Function F(X) to find F(X)=0 solution
Expand All @@ -30,15 +31,16 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 1, 1> X{22.2536};
Eigen::Matrix<double, 1, 1> dX;
Eigen::Matrix<double, 1, 1> F;
Eigen::Matrix<double, 1, 1> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find the solution") {
CAPTURE(iter_newton);
CAPTURE(X);
REQUIRE(iter_newton > 0);
REQUIRE(iter_newton == 1);
REQUIRE_THAT(X[0], Catch::Matchers::WithinRel(1.0, 0.01));
REQUIRE(F.norm() < max_error_norm);
}
Expand All @@ -47,18 +49,20 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("1 non-linear eq") {
struct functor {
void operator()(const Eigen::Matrix<double, 1, 1>& X,
Eigen::Matrix<double, 1, 1>& /* dX */,
Eigen::Matrix<double, 1, 1>& F,
Eigen::Matrix<double, 1, 1>& J) const {
F[0] = -3.0 * X[0] + std::sin(X[0]) + std::log(X[0] * X[0] + 11.435243) + 3.0;
J(0, 0) = -3.0 + std::cos(X[0]) + 2.0 * X[0] / (X[0] * X[0] + 11.435243);
}
};
Eigen::Matrix<double, 1, 1> X{-0.21421};
Eigen::Matrix<double, 1, 1> dX;
Eigen::Matrix<double, 1, 1> F;
Eigen::Matrix<double, 1, 1> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find the solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -71,6 +75,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 2 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 2, 1>& X,
Eigen::Matrix<double, 2, 1>& /* dX */,
Eigen::Matrix<double, 2, 1>& F,
Eigen::Matrix<double, 2, 2>& J) const {
F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1] - 1.0;
Expand All @@ -82,11 +87,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 2, 1> X{0.2, 0.4};
Eigen::Matrix<double, 2, 1> dX;
Eigen::Matrix<double, 2, 1> F;
Eigen::Matrix<double, 2, 2> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -107,6 +113,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
double e = 0.01;
double z = 0.99;
void operator()(const Eigen::Matrix<double, 3, 1>& X,
Eigen::Matrix<double, 3, 1>& /* dX */,
Eigen::Matrix<double, 3, 1>& F,
Eigen::Matrix<double, 3, 3>& J) const {
F(0) = -(-_x_old - dt * (a * std::pow(X[0], 2) + X[1]) + X[0]);
Expand All @@ -124,11 +131,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 3, 1> X{0.21231, 0.4435, -0.11537};
Eigen::Matrix<double, 3, 1> dX;
Eigen::Matrix<double, 3, 1> F;
Eigen::Matrix<double, 3, 3> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -145,6 +153,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
double X3_old = 1.2345;
double dt = 0.2;
void operator()(const Eigen::Matrix<double, 4, 1>& X,
Eigen::Matrix<double, 4, 1>& /* dX */,
Eigen::Matrix<double, 4, 1>& F,
Eigen::Matrix<double, 4, 4>& J) const {
F[0] = -(-3.0 * X[0] * X[2] * dt + X[0] - X0_old + 2.0 * dt / X[1]);
Expand All @@ -170,11 +179,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 4, 1> X{0.21231, 0.4435, -0.11537, -0.8124312};
Eigen::Matrix<double, 4, 1> dX;
Eigen::Matrix<double, 4, 1> F;
Eigen::Matrix<double, 4, 4> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -186,6 +196,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 5 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 5, 1>& X,
Eigen::Matrix<double, 5, 1>& /* dX */,
Eigen::Matrix<double, 5, 1>& F,
Eigen::Matrix<double, 5, 5>& J) const {
F[0] = -3.0 * X[0] * X[2] + X[0] + 2.0 / X[1];
Expand Down Expand Up @@ -224,11 +235,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
};
Eigen::Matrix<double, 5, 1> X;
X << 8.234, -245.46, 123.123, 0.8343, 5.555;
Eigen::Matrix<double, 5, 1> dX;
Eigen::Matrix<double, 5, 1> F;
Eigen::Matrix<double, 5, 5> J;
functor fn;
int iter_newton = newton::newton_solver<5, functor>(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -240,6 +252,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 10 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 10, 1>& X,
Eigen::Matrix<double, 10, 1>& /* dX */,
Eigen::Matrix<double, 10, 1>& F,
Eigen::Matrix<double, 10, 10>& J) const {
F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1];
Expand Down Expand Up @@ -360,11 +373,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")

Eigen::Matrix<double, 10, 1> X;
X << 8.234, -5.46, 1.123, 0.8343, 5.555, 18.234, -2.46, 0.123, 10.8343, -4.685;
Eigen::Matrix<double, 10, 1> dX;
Eigen::Matrix<double, 10, 1> F;
Eigen::Matrix<double, 10, 10> J;
functor fn;
int iter_newton = newton::newton_solver<10, functor>(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand Down

0 comments on commit 29e6b6c

Please sign in to comment.