Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow derivimplicit to use finite differences. #1444

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,40 @@ def solve_lin_system(
return code, new_local_vars


def finite_difference_step_variable(sym):
return f"{sym}_delta_"


def discretize_derivative(expr):
if isinstance(expr, sp.Derivative):
x = expr.args[1][0]
dx = sp.symbols(finite_difference_step_variable(x))
return expr.as_finite_difference(dx)
else:
return expr


def transform_expression(expr, transform):
if not expr.args:
return expr

args = (transform_expression(transform(arg), transform) for arg in expr.args)
return expr.func(*args)


def transform_matrix_elements(mat, transform):
return sp.Matrix(
[
[transform_expression(mat[i, j], transform) for j in range(mat.rows)]
for i in range(mat.cols)
]
)


def needs_finite_differences(mat):
return any(isinstance(expr, sp.Derivative) for expr in sp.preorder_traversal(mat))


def solve_non_lin_system(eq_strings, vars, constants, function_calls):
"""Solve non-linear system of equations, return solution as C code.

Expand All @@ -369,28 +403,31 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
if needs_finite_differences(jacobian):
jacobian = transform_matrix_elements(jacobian, discretize_derivative)

X_vec_map = {x: sp.symbols(f"X[{i}]") for i, x in enumerate(state_vars)}
dX_vec_map = {
finite_difference_step_variable(x): sp.symbols(f"dX_[{i}]")
for i, x in enumerate(state_vars)
}

vecFcode = []
for i, eq in enumerate(eqs):
vecFcode.append(
f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}"
)
expr = eq.simplify().subs(X_vec_map).evalf()
rhs = sp.ccode(expr, user_functions=custom_fcts)
vecFcode.append(f"F[{i}] = {rhs}")

vecJcode = []
for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)):
flat_index = i + jacobian.rows * j

rhs = sp.ccode(
jacobian[i, j].simplify().subs(X_vec_map).evalf(),
user_functions=custom_fcts,
)
Jij = jacobian[i, j].simplify().subs({**X_vec_map, **dX_vec_map}).evalf()
rhs = sp.ccode(Jij, user_functions=custom_fcts)
vecJcode.append(f"J[{flat_index}] = {rhs}")

# interweave
code = _interweave_eqs(vecFcode, vecJcode)

code = search_and_replace_protected_identifiers_from_sympy(code, function_calls)

return code
Expand Down
8 changes: 8 additions & 0 deletions src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,15 +708,23 @@

printer->fmt_text(
"void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, "
"1>& nmodl_eigen_dxm, Eigen::Matrix<{0}, {1}, "
"1>& nmodl_eigen_fm, "
"Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}",
float_type,
N,
is_functor_const(variable_block, functor_block) ? "const " : "");
printer->push_block();
printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_dx = nmodl_eigen_dxm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);

for (size_t i = 0; i < N; ++i) {
printer->fmt_line(
"nmodl_eigen_dx[{0}] = std::max(1e-6, 0.02*std::fabs(nmodl_eigen_x[{0}]));", i);
1uc marked this conversation as resolved.
Show resolved Hide resolved
}

print_statement_block(functor_block, false, false);
printer->pop_block();
printer->add_newline();
Expand Down Expand Up @@ -1665,7 +1673,7 @@
auto rhs = get_variable_name(state_name + "0");

if (state->is_array()) {
auto size = state->get_length();

Check warning on line 1676 in src/codegen/codegen_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "glibc_asserts": "ON", "os": "ubuntu-22.04" }

unused variable ‘size’ [-Wunused-variable]

Check warning on line 1676 in src/codegen/codegen_cpp_visitor.cpp

View workflow job for this annotation

GitHub Actions / { "flag_warnings": "ON", "os": "ubuntu-22.04", "sanitizer": "undefined" }

unused variable 'size' [-Wunused-variable]
for (int i = 0; i < state->get_length(); ++i) {
printer->fmt_line("{}[{}] = {};", lhs, i, rhs);
}
Expand Down
7 changes: 3 additions & 4 deletions src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@ std::tuple<std::vector<std::string>, std::string> call_solve_nonlinear_system(
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
import traceback
solutions = [""]
new_local_vars = [""]
exception_message = traceback.format_exc()
)";

Expand Down
7 changes: 5 additions & 2 deletions src/solver/newton/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix<double, N, 1>& X,
FUNC functor,
double eps = EPS,
int max_iter = MAX_ITER) {
// If finite differences are needed, this is stores the stepwidth.
Eigen::Matrix<double, N, 1> dX;
// Vector to store result of function F(X):
Eigen::Matrix<double, N, 1> F;
// Matrix to store Jacobian of F(X):
Expand All @@ -89,7 +91,7 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix<double, N, 1>& X,
int iter = -1;
while (++iter < max_iter) {
// calculate F, J from X using user-supplied functor
functor(X, F, J);
functor(X, dX, F, J);
if (is_converged(X, J, F, eps)) {
return iter;
}
Expand Down Expand Up @@ -127,10 +129,11 @@ EIGEN_DEVICE_FUNC int newton_solver_small_N(Eigen::Matrix<double, N, 1>& X,
int max_iter) {
bool invertible;
Eigen::Matrix<double, N, 1> F;
Eigen::Matrix<double, N, 1> dX;
Eigen::Matrix<double, N, N> J, J_inv;
int iter = -1;
while (++iter < max_iter) {
functor(X, F, J);
functor(X, dX, F, J);
if (is_converged(X, J, F, eps)) {
return iter;
}
Expand Down
19 changes: 11 additions & 8 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ void SympySolverVisitor::construct_eigen_solver_block(
const std::vector<std::string>& solutions,
bool linear) {
auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x[");
solutions_filtered = filter_string_vector(solutions_filtered, "dX_[", "nmodl_eigen_dx[");
solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j[");
solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm[");
solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f[");
Expand All @@ -187,16 +188,19 @@ void SympySolverVisitor::construct_eigen_solver_block(
logger->debug("SympySolverVisitor :: -> adding statement: {}", sol);
}

std::vector<std::string> pre_solve_statements_and_setup_x_eqs(pre_solve_statements);
std::vector<std::string> pre_solve_statements_and_setup_x_eqs = pre_solve_statements;
std::vector<std::string> update_statements;

for (int i = 0; i < state_vars.size(); i++) {
auto update_state = state_vars[i] + " = nmodl_eigen_x[" + std::to_string(i) + "]";
auto setup_x = "nmodl_eigen_x[" + std::to_string(i) + "] = " + state_vars[i];
auto eigen_name = fmt::format("nmodl_eigen_x[{}]", i);

pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
auto update_state = fmt::format("{} = {}", state_vars[i], eigen_name);
update_statements.push_back(update_state);
logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
logger->debug("SympySolverVisitor :: update_state: {}", update_state);

auto setup_x = fmt::format("{} = {}", eigen_name, state_vars[i]);
pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
}

visitor::SympyReplaceSolutionsVisitor solution_replacer(
Expand Down Expand Up @@ -304,9 +308,7 @@ void SympySolverVisitor::construct_eigen_solver_block(


void SympySolverVisitor::solve_linear_system(const ast::Node& node,
const std::vector<std::string>& pre_solve_statements

) {
const std::vector<std::string>& pre_solve_statements) {
// construct ordered vector of state vars used in linear system
init_state_vars_vector(&node);
// call sympy linear solver
Expand Down Expand Up @@ -373,6 +375,7 @@ void SympySolverVisitor::solve_non_linear_system(
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");

construct_eigen_solver_block(pre_solve_statements, solutions, false);
}

Expand Down
30 changes: 22 additions & 8 deletions test/unit/newton/newton.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("1 linear eq") {
struct functor {
void operator()(const Eigen::Matrix<double, 1, 1>& X,
Eigen::Matrix<double, 1, 1>& /* dX */,
Eigen::Matrix<double, 1, 1>& F,
Eigen::Matrix<double, 1, 1>& J) const {
// Function F(X) to find F(X)=0 solution
Expand All @@ -30,15 +31,16 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 1, 1> X{22.2536};
Eigen::Matrix<double, 1, 1> dX;
Eigen::Matrix<double, 1, 1> F;
Eigen::Matrix<double, 1, 1> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find the solution") {
CAPTURE(iter_newton);
CAPTURE(X);
REQUIRE(iter_newton > 0);
REQUIRE(iter_newton == 1);
REQUIRE_THAT(X[0], Catch::Matchers::WithinRel(1.0, 0.01));
REQUIRE(F.norm() < max_error_norm);
}
Expand All @@ -47,18 +49,20 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("1 non-linear eq") {
struct functor {
void operator()(const Eigen::Matrix<double, 1, 1>& X,
Eigen::Matrix<double, 1, 1>& /* dX */,
Eigen::Matrix<double, 1, 1>& F,
Eigen::Matrix<double, 1, 1>& J) const {
F[0] = -3.0 * X[0] + std::sin(X[0]) + std::log(X[0] * X[0] + 11.435243) + 3.0;
J(0, 0) = -3.0 + std::cos(X[0]) + 2.0 * X[0] / (X[0] * X[0] + 11.435243);
}
};
Eigen::Matrix<double, 1, 1> X{-0.21421};
Eigen::Matrix<double, 1, 1> dX;
Eigen::Matrix<double, 1, 1> F;
Eigen::Matrix<double, 1, 1> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find the solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -71,6 +75,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 2 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 2, 1>& X,
Eigen::Matrix<double, 2, 1>& /* dX */,
Eigen::Matrix<double, 2, 1>& F,
Eigen::Matrix<double, 2, 2>& J) const {
F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1] - 1.0;
Expand All @@ -82,11 +87,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 2, 1> X{0.2, 0.4};
Eigen::Matrix<double, 2, 1> dX;
Eigen::Matrix<double, 2, 1> F;
Eigen::Matrix<double, 2, 2> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -107,6 +113,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
double e = 0.01;
double z = 0.99;
void operator()(const Eigen::Matrix<double, 3, 1>& X,
Eigen::Matrix<double, 3, 1>& /* dX */,
Eigen::Matrix<double, 3, 1>& F,
Eigen::Matrix<double, 3, 3>& J) const {
F(0) = -(-_x_old - dt * (a * std::pow(X[0], 2) + X[1]) + X[0]);
Expand All @@ -124,11 +131,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 3, 1> X{0.21231, 0.4435, -0.11537};
Eigen::Matrix<double, 3, 1> dX;
Eigen::Matrix<double, 3, 1> F;
Eigen::Matrix<double, 3, 3> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -145,6 +153,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
double X3_old = 1.2345;
double dt = 0.2;
void operator()(const Eigen::Matrix<double, 4, 1>& X,
Eigen::Matrix<double, 4, 1>& /* dX */,
Eigen::Matrix<double, 4, 1>& F,
Eigen::Matrix<double, 4, 4>& J) const {
F[0] = -(-3.0 * X[0] * X[2] * dt + X[0] - X0_old + 2.0 * dt / X[1]);
Expand All @@ -170,11 +179,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
}
};
Eigen::Matrix<double, 4, 1> X{0.21231, 0.4435, -0.11537, -0.8124312};
Eigen::Matrix<double, 4, 1> dX;
Eigen::Matrix<double, 4, 1> F;
Eigen::Matrix<double, 4, 4> J;
functor fn;
int iter_newton = newton::newton_solver(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -186,6 +196,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 5 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 5, 1>& X,
Eigen::Matrix<double, 5, 1>& /* dX */,
Eigen::Matrix<double, 5, 1>& F,
Eigen::Matrix<double, 5, 5>& J) const {
F[0] = -3.0 * X[0] * X[2] + X[0] + 2.0 / X[1];
Expand Down Expand Up @@ -224,11 +235,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
};
Eigen::Matrix<double, 5, 1> X;
X << 8.234, -245.46, 123.123, 0.8343, 5.555;
Eigen::Matrix<double, 5, 1> dX;
Eigen::Matrix<double, 5, 1> F;
Eigen::Matrix<double, 5, 5> J;
functor fn;
int iter_newton = newton::newton_solver<5, functor>(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand All @@ -240,6 +252,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")
GIVEN("system of 10 non-linear eqs") {
struct functor {
void operator()(const Eigen::Matrix<double, 10, 1>& X,
Eigen::Matrix<double, 10, 1>& /* dX */,
Eigen::Matrix<double, 10, 1>& F,
Eigen::Matrix<double, 10, 10>& J) const {
F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1];
Expand Down Expand Up @@ -360,11 +373,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]")

Eigen::Matrix<double, 10, 1> X;
X << 8.234, -5.46, 1.123, 0.8343, 5.555, 18.234, -2.46, 0.123, 10.8343, -4.685;
Eigen::Matrix<double, 10, 1> dX;
Eigen::Matrix<double, 10, 1> F;
Eigen::Matrix<double, 10, 10> J;
functor fn;
int iter_newton = newton::newton_solver<10, functor>(X, fn);
fn(X, F, J);
fn(X, dX, F, J);
THEN("find a solution") {
CAPTURE(iter_newton);
CAPTURE(X);
Expand Down
30 changes: 30 additions & 0 deletions test/usecases/solve/finite_difference.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
NEURON {
SUFFIX finite_difference
GLOBAL a
THREADSAFE
}

ASSIGNED {
a
}

STATE {
x
}

INITIAL {
x = 42.0
a = 0.1
}

BREAKPOINT {
SOLVE dX METHOD derivimplicit
}

DERIVATIVE dX {
x' = -f(x)
}

FUNCTION f(x) {
f = a*x
}
Loading
Loading