There are several places where we modify JAX during compilation.
Just to list some:
```python
# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
```
```python
# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)
```
Patchers (see `jax_extras` and `jax_transient_config`).
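The `__hash__` patch in the first item is installed once and never undone, so it leaks into everything that runs afterwards. Below is a minimal sketch of confining such a patch to a scope and restoring the class on exit. `patched_attr` is a hypothetical helper (not part of Catalyst or JAX), and `FakeTracer` is an assumed stand-in for `DynamicJaxprTracer` so the example runs without JAX.

```python
import contextlib
from typing import Any, Iterator

_MISSING = object()  # sentinel: the class had no attribute of its own


@contextlib.contextmanager
def patched_attr(cls: type, name: str, value: Any) -> Iterator[None]:
    """Temporarily set ``cls.<name>`` to ``value`` and undo the change on exit."""
    # Only look at the class's own __dict__, so an inherited attribute is
    # restored by deleting our override rather than copying it onto ``cls``.
    original = cls.__dict__.get(name, _MISSING)
    setattr(cls, name, value)
    try:
        yield
    finally:
        if original is _MISSING:
            delattr(cls, name)
        else:
            setattr(cls, name, original)


class FakeTracer:
    """Stand-in for DynamicJaxprTracer, which refuses to be hashed."""

    def __hash__(self):
        raise TypeError("unhashable tracer")


tracer = FakeTracer()
with patched_attr(FakeTracer, "__hash__", lambda x: id(x)):
    wires = {tracer}  # hashable only inside this scope

try:
    hash(tracer)
    restored = False
except TypeError:
    restored = True  # the original __hash__ is back after the block
print(restored)  # True
```

The `lambda x: id(x)` form that pylint flags is actually needed: a builtin such as `id` stored on a class does not bind `self`, so `__hash__ = id` would call `id()` with no arguments.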
We also have a global context, `EvaluationContext`, that tracks whether we are running in plain Python or tracing with JAX.
Callbacks break the assumption that once we start tracing, we never return to the Python environment. We need a function that saves the configuration before tracing, changes it however we need, restores the original configuration during callbacks, and re-applies our configuration once we exit the callback scope.
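A minimal sketch of that save/apply/restore cycle as a context manager. `swapped_config` is a hypothetical helper, not an existing Catalyst function, and a plain dict stands in for the global configuration so the example runs without JAX. With JAX one would presumably pass `jax.config.update` as `update` and something like `lambda name: getattr(jax.config, name)` as `read`; the latter is an assumption about how flag values are read back.

```python
import contextlib
from typing import Any, Callable, Dict, Iterator, Mapping


@contextlib.contextmanager
def swapped_config(
    read: Callable[[str], Any],
    update: Callable[[str, Any], None],
    values: Mapping[str, Any],
) -> Iterator[Dict[str, Any]]:
    """Apply ``values``, yield the settings they replaced, restore them on exit."""
    previous = {name: read(name) for name in values}
    for name, value in values.items():
        update(name, value)
    try:
        yield previous
    finally:
        # Restore even if tracing or the callback raised.
        for name, value in previous.items():
            update(name, value)


# A plain dict stands in for the global JAX configuration in this demo.
store: Dict[str, Any] = {"jax_enable_x64": False}
read, update = store.__getitem__, store.__setitem__


def user_callback() -> Any:
    # Python code run from a callback should see the user's own settings.
    return store["jax_enable_x64"]


with swapped_config(read, update, {"jax_enable_x64": True}) as user_config:
    # While tracing, the compiler's setting is active.
    assert store["jax_enable_x64"] is True
    with swapped_config(read, update, user_config):
        seen_in_callback = user_callback()
    # Back in the tracing scope: the compiler's setting is re-applied.
    assert store["jax_enable_x64"] is True

print(seen_in_callback, store["jax_enable_x64"])  # False False
```

Re-entering the same manager with the yielded `user_config` is what "reset it during callbacks" amounts to here; the tracing-time values come back automatically when the inner block exits. As the comment on `jax_enable_x64` warns, values created under one setting may not be consistent with the other, so flipping some flags mid-trace may not be safe.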
Note: rather than patching `jax.interpreters.partial_eval.DynamicJaxprTracer` to add a `__hash__`, could we change the PennyLane wire utilities to detect when a wire is a `DynamicJaxprTracer` and use its `id` as the hash? That would avoid modifying JAX itself.
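A minimal sketch of the wire-side alternative, assuming the wire utilities can route hashing through one helper. `wire_key` and `TRACER_TYPES` are hypothetical names, not PennyLane API, and `FakeTracer` stands in for `DynamicJaxprTracer` so the example runs without JAX.

```python
from typing import Any, Hashable, Tuple, Type


class FakeTracer:
    """Stand-in for jax.interpreters.partial_eval.DynamicJaxprTracer."""

    def __hash__(self):
        raise TypeError("unhashable tracer")


# In real code this tuple would hold the JAX tracer class(es) instead.
TRACER_TYPES: Tuple[Type, ...] = (FakeTracer,)


def wire_key(wire: Any) -> Hashable:
    """Return a hashable key for a wire label without patching the tracer class."""
    if isinstance(wire, TRACER_TYPES):
        # Tag the id so a tracer can never collide with an integer wire label.
        return ("tracer", id(wire))
    return wire


tracer = FakeTracer()
labels = [0, "aux", tracer]
index = {wire_key(w): i for i, w in enumerate(labels)}
print(index[wire_key(tracer)], index[wire_key("aux")])  # 2 1
```

Like the `lambda x: id(x)` patch, this gives tracers identity semantics: a key is only meaningful while the tracer is alive, and two distinct tracers never compare equal as wires.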
> Note: rather than patching `jax.interpreters.partial_eval.DynamicJaxprTracer` to add a `__hash__`, could we change the PennyLane wire utilities to detect when a wire is a `DynamicJaxprTracer` and use its `id` as the hash? That would avoid modifying JAX itself.
Yes, I think that's a good idea! The capture module will probably take care of this.
The bug is that we change global state when we shouldn't 😅. The concrete error is #894, but that is just a symptom.