Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.experimental.enable_x64 and jit #523

Open
dv-ai opened this issue Nov 12, 2024 · 1 comment
Open

jax.experimental.enable_x64 and jit #523

dv-ai opened this issue Nov 12, 2024 · 1 comment
Labels
question User queries

Comments

@dv-ai
Copy link

dv-ai commented Nov 12, 2024

diffrax 0.6.0
jax 0.4.30

Using jax.experimental.enable_x64 together with jax.jit raises an exception.

import jax
import jax.numpy as jnp
import diffrax

key = jax.random.PRNGKey(0)
# NOTE: x1 and x2 are both drawn from the same (unsplit) key. JAX normally
# recommends splitting keys; presumably the values are irrelevant to this repro.
x1, x2 = (jax.random.uniform(key, shape) for shape in [(2,), (2, 2)])

def odes(t, y, w):
    """Vector field of the linear ODE dy/dt = w @ y.

    Args:
        t: current time; not used (the system is autonomous).
        y: current state.
        w: matrix supplied through diffrax's ``args``.

    Returns:
        The matrix product ``w @ y``.
    """
    return w @ y

# Define a single-solve function with a fixed final time (t1=1.0), computed in float64 and cast back to float32
def ode_fun(y0, w):
    """Integrate dy/dt = w @ y from t=0 to t=1 in float64 and return float32.

    Inside a ``jax.experimental.enable_x64`` context, the inputs are cast to
    float64. The system is solved with Dopri5 under a PID step-size
    controller (atol = rtol = 1e-8), saving only at t1 = 1.0. The saved
    states are then cast back to float32.

    Args:
        y0: initial state (shape (2,) in this script).
        w: system matrix forwarded as ``args`` to ``odes`` (shape (2, 2) in
            this script).

    Returns:
        ``sol.ys`` cast to float32.

    NOTE(review): calling this eagerly works. Wrapping it in ``jax.jit`` was
    reported to fail with ``ValueError: Operation creation failed``
    (jax 0.4.30 / diffrax 0.6.0), and the maintainer describes
    ``enable_x64`` as buggy/experimental. Enabling x64 globally via
    ``jax.config.update("jax_enable_x64", True)`` is presumably the safer
    route — confirm.
    """
    with jax.experimental.enable_x64():
        y_init = y0.astype(jnp.float64)
        matrix = w.astype(jnp.float64)

        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(odes),
            diffrax.Dopri5(),
            t0=0.0,
            t1=1.0,  # fixed final time
            y0=y_init,
            saveat=diffrax.SaveAt(t1=True),
            stepsize_controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),
            dt0=None,
            max_steps=None,
            args=matrix,
        )
        out = solution.ys.astype(jnp.float32)

    return out

# Eager (un-jitted) call: reported to run successfully.
print(ode_fun(x1,x2)) # working
ode_fun_jit = jax.jit(ode_fun)
# Jitted call: reported to raise "ValueError: Operation creation failed"
# (jax 0.4.30, diffrax 0.6.0). The traceback ends in the StableHLO
# dynamic_slice builder. NOTE(review): presumably the enable_x64 context
# inside ode_fun only covers tracing, while lowering runs after it has
# exited — confirm.
print(ode_fun_jit(x1, x2)) # exception see below

Exception raised:

Traceback (most recent call last):
  File "tmp.py", line 122, in <module>
    print(ode_fun(x1,x2))
  File "tmp.py", line 116, in ode_fun
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w)  # Fixed t1=1.0
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 1337, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_adjoint.py", line 292, in loop
    final_state = self._loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 621, in loop
    final_state = outer_while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 103, in while_loop
    _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 618, in body_fun
    new_state, _, _ = body_fun_aux(state)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 332, in body_fun_aux
    (y, y_error, dense_info, solver_state, solver_result) = solver.step(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 1151, in step
    final_val = eqxi.while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 107, in while_loop
    return checkpointed_while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\checkpointed.py", line 247, in checkpointed_while_loop
    body_fun_ = filter_closure_convert(body_fun_, init_val_)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in rk_stage
    a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 604, in t_map
    return jtu.tree_map(_fn, tableaus, *trees)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 602, in _fn
    return fn(*_trees)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in <lambda>
    a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 739, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 352, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6579, in _rewriting_take
    if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6563, in _attempt_rewriting_take_via_slice
    arr = lax.dynamic_slice(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Operation creation failed

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tmp.py", line 125, in <module>
    print(ode_fun_jit(x1, x2))
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2487, in dynamic_slice
    return _get_op_result_or_op_results(DynamicSliceOp(operand=operand, start_indices=start_indices, slice_sizes=slice_sizes, loc=loc, ip=ip))
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2461, in __init__
    super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
ValueError: Operation creation failed
@patrick-kidger
Copy link
Owner

AFAIK jax.experimental.enable_x64 is still fairly buggy/experimental, in part because of things like this. I don't think it's recommended for use.

@patrick-kidger patrick-kidger added the question User queries label Nov 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question User queries
Projects
None yet
Development

No branches or pull requests

2 participants