From befae93ae84c04587e89fd46e16ac818e94fcad3 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Mon, 17 Jul 2023 12:25:51 -0400 Subject: [PATCH 1/2] Properly cache endpoints in InterpolatedUnivariateSpline --- jax_cosmo/scipy/interpolate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 40ad80d..7b3278f 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -226,6 +226,7 @@ def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None): # Saving spline parameters for evaluation later self.k = k + self._endpoints = endpoints self._x = x self._y = y self._coefficients = coefficients From c93f63cbbed7a07a96eb333e2a11f109ed0e31fb Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Mon, 17 Jul 2023 12:26:09 -0400 Subject: [PATCH 2/2] Add unit test confirming that its possible to use InterpolatedUnivariateSpline on PyTrees --- tests/test_spline.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_spline.py b/tests/test_spline.py index 725039d..dc41fc4 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -3,6 +3,7 @@ from jax.config import config config.update("jax_enable_x64", True) +import jax import jax.numpy as np from numpy.testing import assert_allclose @@ -72,3 +73,62 @@ def test_cubic_spline(): a = spl_ref.antiderivative()(t) - spl_ref.antiderivative()(0.01) b = spl.antiderivative(t) - spl.antiderivative(0.01) assert_allclose(a, b, rtol=1e-10) + + +def test_spline_pytree(): + """ + Test that we can interpolate over pytrees. + """ + + # Time and data structure to interpolate over. + ts = np.linspace(0, 1, 10) + us = { + "a": np.linspace(0.0, 1.0, 10), + "b": { + "b0": np.linspace(0.0, 0.1, 10), + "b1": np.linspace(0.0, 0.2, 10), + }, + } + + # Generate a pytree of splines with the same structure as "us". + spline_order = 1 + spline_tree = jax.tree_util.tree_map( + lambda u: InterpolatedUnivariateSpline(ts, u, spline_order), us + ) + + def eval_splines(t): + return jax.tree_util.tree_map( + lambda sp: sp(t), + spline_tree, + is_leaf=lambda obj: isinstance(obj, InterpolatedUnivariateSpline), + ) + + # Evaluate the splines at t=0.0. + out0 = eval_splines(0.0) + assert out0 == { + "a": 0.0, + "b": { + "b0": 0.0, + "b1": 0.0, + }, + } + + # Evaluate the splines at t=0.5. + out05 = eval_splines(0.5) + assert out05 == { + "a": 0.5, + "b": { + "b0": 0.05, + "b1": 0.1, + }, + } + + # Evaluate the splines at t=1.0. + out1 = eval_splines(1.0) + assert out1 == { + "a": 1.0, + "b": { + "b0": 0.1, + "b1": 0.2, + }, + }