From 29e6b6c2864b91e07b96ae69d9b72cf2da34daf5 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Tue, 17 Sep 2024 17:11:48 +0200 Subject: [PATCH 1/3] Allow `derivimplicit` to use finite differences. In NOCMODL `derivimplicit` uses finite differences to compute each element of the Jacobian. In stead we try to use SymPy, however, if it fails, e.g. because it encounters an opaque function, we allow it to use a finite difference instead. --- python/nmodl/ode.py | 68 +++++++++++++++++++++++---- src/codegen/codegen_cpp_visitor.cpp | 8 ++++ src/pybind/wrapper.cpp | 7 ++- src/solver/newton/newton.hpp | 7 ++- src/visitors/sympy_solver_visitor.cpp | 19 ++++---- test/unit/newton/newton.cpp | 30 ++++++++---- 6 files changed, 109 insertions(+), 30 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index e5cc926d1..5e87888cf 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -344,6 +344,55 @@ def solve_lin_system( return code, new_local_vars +def finite_difference_step_variable(sym): + return f"{sym}_delta_" + + +def discretize_derivative(expr): + if isinstance(expr, sp.Derivative): + x = expr.args[1][0] + dx = sp.symbols(finite_difference_step_variable(x)) + return expr.as_finite_difference(dx) + else: + return expr + + +def transform_expression(expr, transform): + if expr.args is tuple(): + return expr + + args = list(transform_expression(transform(arg), transform) for arg in expr.args) + return expr.func(*args) + + +def transform_matrix_elements(mat, transform): + return sp.Matrix( + [ + [transform_expression(mat[i, j], transform) for j in range(mat.rows)] + for i in range(mat.cols) + ] + ) + + +def finite_difference_variables(mat): + vars = [] + + def recurse(expr): + for arg in expr.args: + if isinstance(arg, sp.Derivative): + var = arg.args[1][0] + vars.append((var, finite_difference_step_variable(var))) + + for expr in mat: + recurse(expr) + + return vars + + +def needs_finite_differences(mat): + return any(isinstance(expr, sp.Derivative) for expr in sp.preorder_traversal(mat)) + + def solve_non_lin_system(eq_strings, vars, constants, function_calls): """Solve non-linear system of equations, return solution as C code. @@ -369,28 +418,31 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): custom_fcts = _get_custom_functions(function_calls) jacobian = sp.Matrix(eqs).jacobian(state_vars) + if needs_finite_differences(jacobian): + jacobian = transform_matrix_elements(jacobian, discretize_derivative) X_vec_map = {x: sp.symbols(f"X[{i}]") for i, x in enumerate(state_vars)} + dX_vec_map = { + finite_difference_step_variable(x): sp.symbols(f"dX_[{i}]") + for i, x in enumerate(state_vars) + } vecFcode = [] for i, eq in enumerate(eqs): - vecFcode.append( - f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}" - ) + expr = eq.simplify().subs(X_vec_map).evalf() + rhs = sp.ccode(expr, user_functions=custom_fcts) + vecFcode.append(f"F[{i}] = {rhs}") vecJcode = [] for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)): flat_index = i + jacobian.rows * j - rhs = sp.ccode( - jacobian[i, j].simplify().subs(X_vec_map).evalf(), - user_functions=custom_fcts, - ) + Jij = jacobian[i, j].simplify().subs({**X_vec_map, **dX_vec_map}).evalf() + rhs = sp.ccode(Jij, user_functions=custom_fcts) vecJcode.append(f"J[{flat_index}] = {rhs}") # interweave code = _interweave_eqs(vecFcode, vecJcode) - code = search_and_replace_protected_identifiers_from_sympy(code, function_calls) return code diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index 09bfbea85..c07a2e49e 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -708,6 +708,7 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo printer->fmt_text( "void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, " + "1>& nmodl_eigen_dxm, Eigen::Matrix<{0}, {1}, " "1>& nmodl_eigen_fm, " "Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}", float_type, @@ -715,8 +716,15 @@ void CodegenCppVisitor::print_functor_definition(const ast::EigenNewtonSolverBlo is_functor_const(variable_block, functor_block) ? "const " : ""); printer->push_block(); printer->fmt_line("const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type); + printer->fmt_line("{}* nmodl_eigen_dx = nmodl_eigen_dxm.data();", float_type); printer->fmt_line("{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type); printer->fmt_line("{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type); + + for (size_t i = 0; i < N; ++i) { + printer->fmt_line( + "nmodl_eigen_dx[{0}] = std::max(1e-6, 0.02*std::fabs(nmodl_eigen_x[{0}]));", i); + } + print_statement_block(functor_block, false, false); printer->pop_block(); printer->add_newline(); diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 92f679a6a..32c390736 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -82,14 +82,13 @@ std::tuple, std::string> call_solve_nonlinear_system( exception_message = "" try: solutions = solve_non_lin_system(equation_strings, - state_vars, - vars, - function_calls) + state_vars, + vars, + function_calls) except Exception as e: # if we fail, fail silently and return empty string import traceback solutions = [""] - new_local_vars = [""] exception_message = traceback.format_exc() )"; diff --git a/src/solver/newton/newton.hpp b/src/solver/newton/newton.hpp index bd627d0db..5d4e05182 100644 --- a/src/solver/newton/newton.hpp +++ b/src/solver/newton/newton.hpp @@ -81,6 +81,8 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix& X, FUNC functor, double eps = EPS, int max_iter = MAX_ITER) { + // If finite differences are needed, this is stores the stepwidth. + Eigen::Matrix dX; // Vector to store result of function F(X): Eigen::Matrix F; // Matrix to store Jacobian of F(X): @@ -89,7 +91,7 @@ EIGEN_DEVICE_FUNC int newton_solver(Eigen::Matrix& X, int iter = -1; while (++iter < max_iter) { // calculate F, J from X using user-supplied functor - functor(X, F, J); + functor(X, dX, F, J); if (is_converged(X, J, F, eps)) { return iter; } @@ -127,10 +129,11 @@ EIGEN_DEVICE_FUNC int newton_solver_small_N(Eigen::Matrix& X, int max_iter) { bool invertible; Eigen::Matrix F; + Eigen::Matrix dX; Eigen::Matrix J, J_inv; int iter = -1; while (++iter < max_iter) { - functor(X, F, J); + functor(X, dX, F, J); if (is_converged(X, J, F, eps)) { return iter; } diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index 8e3c8fdc1..f2d6260c2 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -179,6 +179,7 @@ void SympySolverVisitor::construct_eigen_solver_block( const std::vector& solutions, bool linear) { auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x["); + solutions_filtered = filter_string_vector(solutions_filtered, "dX_[", "nmodl_eigen_dx["); solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j["); solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm["); solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f["); @@ -187,16 +188,19 @@ void SympySolverVisitor::construct_eigen_solver_block( logger->debug("SympySolverVisitor :: -> adding statement: {}", sol); } - std::vector pre_solve_statements_and_setup_x_eqs(pre_solve_statements); + std::vector pre_solve_statements_and_setup_x_eqs = pre_solve_statements; std::vector update_statements; + for (int i = 0; i < state_vars.size(); i++) { - auto update_state = state_vars[i] + " = nmodl_eigen_x[" + std::to_string(i) + "]"; - auto setup_x = "nmodl_eigen_x[" + std::to_string(i) + "] = " + state_vars[i]; + auto eigen_name = fmt::format("nmodl_eigen_x[{}]", i); - pre_solve_statements_and_setup_x_eqs.push_back(setup_x); + auto update_state = fmt::format("{} = {}", state_vars[i], eigen_name); update_statements.push_back(update_state); - logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x); logger->debug("SympySolverVisitor :: update_state: {}", update_state); + + auto setup_x = fmt::format("{} = {}", eigen_name, state_vars[i]); + pre_solve_statements_and_setup_x_eqs.push_back(setup_x); + logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x); } visitor::SympyReplaceSolutionsVisitor solution_replacer( @@ -304,9 +308,7 @@ void SympySolverVisitor::construct_eigen_solver_block( void SympySolverVisitor::solve_linear_system(const ast::Node& node, - const std::vector& pre_solve_statements - -) { + const std::vector& pre_solve_statements) { // construct ordered vector of state vars used in linear system init_state_vars_vector(&node); // call sympy linear solver @@ -373,6 +375,7 @@ void SympySolverVisitor::solve_non_linear_system( return; } logger->debug("SympySolverVisitor :: Constructing eigen newton solve block"); + construct_eigen_solver_block(pre_solve_statements, solutions, false); } diff --git a/test/unit/newton/newton.cpp b/test/unit/newton/newton.cpp index 7ed95df89..cd2de7d3f 100644 --- a/test/unit/newton/newton.cpp +++ b/test/unit/newton/newton.cpp @@ -21,6 +21,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") GIVEN("1 linear eq") { struct functor { void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { // Function F(X) to find F(X)=0 solution @@ -30,15 +31,16 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") } }; Eigen::Matrix X{22.2536}; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find the solution") { CAPTURE(iter_newton); CAPTURE(X); - REQUIRE(iter_newton > 0); + REQUIRE(iter_newton == 1); REQUIRE_THAT(X[0], Catch::Matchers::WithinRel(1.0, 0.01)); REQUIRE(F.norm() < max_error_norm); } @@ -47,6 +49,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") GIVEN("1 non-linear eq") { struct functor { void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F[0] = -3.0 * X[0] + std::sin(X[0]) + std::log(X[0] * X[0] + 11.435243) + 3.0; @@ -54,11 +57,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") } }; Eigen::Matrix X{-0.21421}; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find the solution") { CAPTURE(iter_newton); CAPTURE(X); @@ -71,6 +75,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") GIVEN("system of 2 non-linear eqs") { struct functor { void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1] - 1.0; @@ -82,11 +87,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") } }; Eigen::Matrix X{0.2, 0.4}; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find a solution") { CAPTURE(iter_newton); CAPTURE(X); @@ -107,6 +113,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") double e = 0.01; double z = 0.99; void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F(0) = -(-_x_old - dt * (a * std::pow(X[0], 2) + X[1]) + X[0]); @@ -124,11 +131,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") } }; Eigen::Matrix X{0.21231, 0.4435, -0.11537}; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find a solution") { CAPTURE(iter_newton); CAPTURE(X); @@ -145,6 +153,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") double X3_old = 1.2345; double dt = 0.2; void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F[0] = -(-3.0 * X[0] * X[2] * dt + X[0] - X0_old + 2.0 * dt / X[1]); @@ -170,11 +179,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") } }; Eigen::Matrix X{0.21231, 0.4435, -0.11537, -0.8124312}; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find a solution") { CAPTURE(iter_newton); CAPTURE(X); @@ -186,6 +196,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") GIVEN("system of 5 non-linear eqs") { struct functor { void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F[0] = -3.0 * X[0] * X[2] + X[0] + 2.0 / X[1]; @@ -224,11 +235,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") }; Eigen::Matrix X; X << 8.234, -245.46, 123.123, 0.8343, 5.555; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver<5, functor>(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find a solution") { CAPTURE(iter_newton); CAPTURE(X); @@ -240,6 +252,7 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") GIVEN("system of 10 non-linear eqs") { struct functor { void operator()(const Eigen::Matrix& X, + Eigen::Matrix& /* dX */, Eigen::Matrix& F, Eigen::Matrix& J) const { F[0] = -3.0 * X[0] * X[1] + X[0] + 2.0 * X[1]; @@ -360,11 +373,12 @@ SCENARIO("Non-linear system to solve with Newton Solver", "[analytic][solver]") Eigen::Matrix X; X << 8.234, -5.46, 1.123, 0.8343, 5.555, 18.234, -2.46, 0.123, 10.8343, -4.685; + Eigen::Matrix dX; Eigen::Matrix F; Eigen::Matrix J; functor fn; int iter_newton = newton::newton_solver<10, functor>(X, fn); - fn(X, F, J); + fn(X, dX, F, J); THEN("find a solution") { CAPTURE(iter_newton); CAPTURE(X); From 7b9225c5bcd190606b8e164b574f2614781469e8 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Fri, 20 Sep 2024 09:44:28 +0200 Subject: [PATCH 2/3] Add test. --- test/usecases/solve/finite_difference.mod | 30 +++++++++++++++++++ test/usecases/solve/test_finite_difference.py | 30 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 test/usecases/solve/finite_difference.mod create mode 100644 test/usecases/solve/test_finite_difference.py diff --git a/test/usecases/solve/finite_difference.mod b/test/usecases/solve/finite_difference.mod new file mode 100644 index 000000000..2c0f94e86 --- /dev/null +++ b/test/usecases/solve/finite_difference.mod @@ -0,0 +1,30 @@ +NEURON { + SUFFIX finite_difference + GLOBAL a + THREADSAFE +} + +ASSIGNED { + a +} + +STATE { + x +} + +INITIAL { + x = 42.0 + a = 0.1 +} + +BREAKPOINT { + SOLVE dX METHOD derivimplicit +} + +DERIVATIVE dX { + x' = -f(x) +} + +FUNCTION f(x) { + f = a*x +} diff --git a/test/usecases/solve/test_finite_difference.py b/test/usecases/solve/test_finite_difference.py new file mode 100644 index 000000000..ba204ff2f --- /dev/null +++ b/test/usecases/solve/test_finite_difference.py @@ -0,0 +1,30 @@ +import numpy as np +from neuron import h, gui +from neuron.units import ms + + +def test_finite_difference(): + nseg = 1 + + s = h.Section() + s.insert("finite_difference") + s.nseg = nseg + + x_hoc = h.Vector().record(getattr(s(0.5), f"_ref_x_finite_difference")) + t_hoc = h.Vector().record(h._ref_t) + + h.stdinit() + h.dt = 0.001 + h.tstop = 5.0 * ms + h.run() + + x = np.array(x_hoc.as_numpy()) + t = np.array(t_hoc.as_numpy()) + + a = h.a_finite_difference + x_exact = 42.0 * np.exp(-a * t) + np.testing.assert_allclose(x, x_exact, rtol=1e-4) + + +if __name__ == "__main__": + test_finite_difference() From 5576cc965e5f4cbe5c9020e84f8f84b5a8fb55a8 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Fri, 20 Sep 2024 16:43:33 +0200 Subject: [PATCH 3/3] Fixups. --- python/nmodl/ode.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 5e87888cf..dd6f7e068 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -358,10 +358,10 @@ def discretize_derivative(expr): def transform_expression(expr, transform): - if expr.args is tuple(): + if len(expr.args) == 0: return expr - args = list(transform_expression(transform(arg), transform) for arg in expr.args) + args = (transform_expression(transform(arg), transform) for arg in expr.args) return expr.func(*args) @@ -374,21 +374,6 @@ def transform_matrix_elements(mat, transform): ) -def finite_difference_variables(mat): - vars = [] - - def recurse(expr): - for arg in expr.args: - if isinstance(arg, sp.Derivative): - var = arg.args[1][0] - vars.append((var, finite_difference_step_variable(var))) - - for expr in mat: - recurse(expr) - - return vars - - def needs_finite_differences(mat): return any(isinstance(expr, sp.Derivative) for expr in sp.preorder_traversal(mat))