You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Using jax.experimental.enable_x64 and jit will return a exception.
importjaximportjax.numpyasjnpimportdiffraxkey=jax.random.PRNGKey(0)
x1=jax.random.uniform(key,(2,))
x2=jax.random.uniform(key,(2,2))
defodes(t, y, w):
result=w @ yreturnresult# Define single solve function for fixed final time with scalingdefode_fun(y0, w):
withjax.experimental.enable_x64():
y0=y0.astype(jnp.float64)
w=w.astype(jnp.float64)
term=diffrax.ODETerm(odes)
solver=diffrax.Dopri5()
controler=diffrax.PIDController(atol=1E-8,rtol=1E-8)
saveat=diffrax.SaveAt(t1=True)
sol=diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w) # Fixed t1=1.0result=sol.ysresult=result.astype(jnp.float32)
returnresultprint(ode_fun(x1,x2)) # workingode_fun_jit=jax.jit(ode_fun)
print(ode_fun_jit(x1, x2)) # exception see below
Exception raised:
raceback (most recent call last):
File "tmp.py", line 122, in <module>
print(ode_fun(x1,x2))
File "tmp.py", line 116, in ode_fun
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w) # Fixed t1=1.0
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 1337, in diffeqsolve
final_state, aux_stats = adjoint.loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_adjoint.py", line 292, in loop
final_state = self._loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 621, in loop
final_state = outer_while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\contextlib.py", line 79, in inner
return func(*args, **kwds)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 103, in while_loop
_, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 618, in body_fun
new_state, _, _ = body_fun_aux(state)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 332, in body_fun_aux
(y, y_error, dense_info, solver_state, solver_result) = solver.step(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 1151, in step
final_val = eqxi.while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 107, in while_loop
return checkpointed_while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\checkpointed.py", line 247, in checkpointed_while_loop
body_fun_ = filter_closure_convert(body_fun_, init_val_)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in rk_stage
a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 604, in t_map
return jtu.tree_map(_fn, tableaus, *trees)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 602, in _fn
return fn(*_trees)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in <lambda>
a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 739, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 352, in _getitem
return lax_numpy._rewriting_take(self, item)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6579, in _rewriting_take
if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6563, in _attempt_rewriting_take_via_slice
arr = lax.dynamic_slice(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Operation creation failed
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "tmp.py", line 125, in <module>
print(ode_fun_jit(x1, x2))
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2487, in dynamic_slice
return _get_op_result_or_op_results(DynamicSliceOp(operand=operand, start_indices=start_indices, slice_sizes=slice_sizes, loc=loc, ip=ip))
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2461, in __init__
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
ValueError: Operation creation failed
The text was updated successfully, but these errors were encountered:
diffrax 0.6.0
jax 0.4.30
Using jax.experimental.enable_x64 and jit will return a exception.
Exception raised:
The text was updated successfully, but these errors were encountered: