
Function for setting and resetting global changes during compilation #913

Open
erick-xanadu opened this issue Jul 9, 2024 · 2 comments
Labels
bug (Something isn't working), chore, enhancement (New feature or request)

Comments

@erick-xanadu
Contributor

There are several places where we modify JAX during compilation.

Just to list some:

```python
# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
```

```python
# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)
```

Various patchers (see jax_extras and jax_transient_config).

We also have a global context, the EvaluationContext, that tracks whether or not we are currently tracing with JAX.
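A minimal sketch of how such a "are we tracing?" flag could be scoped instead of left as mutable global state, assuming a contextvars-based design. The names tracing_mode and is_tracing are hypothetical; this is not how Catalyst's EvaluationContext is actually implemented.

```python
import contextlib
import contextvars

# Hypothetical stand-in for the "are we tracing?" flag; this is NOT
# Catalyst's actual EvaluationContext implementation.
_is_tracing = contextvars.ContextVar("is_tracing", default=False)


def is_tracing() -> bool:
    """Return True if the current context is tracing."""
    return _is_tracing.get()


@contextlib.contextmanager
def tracing_mode(enabled: bool):
    """Set the flag for the duration of the block, then restore the previous value."""
    token = _is_tracing.set(enabled)
    try:
        yield
    finally:
        # reset() restores whatever value was active before set(), so nesting works.
        _is_tracing.reset(token)


with tracing_mode(True):
    assert is_tracing()
    with tracing_mode(False):  # e.g. while a Python callback runs
        assert not is_tracing()
    assert is_tracing()
assert not is_tracing()
```

Because each change is undone through a token rather than by writing a fixed value back, a callback nested inside tracing can flip the flag and the outer tracing state is recovered automatically on exit.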

Callbacks break the assumption that, once we are tracing, we never return to the Python environment. We should have a function that saves the configuration before tracing, changes it as needed, restores the original configuration while a callback runs, and re-applies our changes once the callback scope exits.
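The save/modify/restore behaviour described above could look roughly like the following sketch. GlobalPatches, FakeTracer, and the config namespace are hypothetical stand-ins (the real targets would be DynamicJaxprTracer.__hash__ and JAX's jax_enable_x64 flag), chosen so the example runs without JAX installed.

```python
import contextlib
from types import SimpleNamespace

_MISSING = object()  # sentinel: attribute was not defined directly on the target


class GlobalPatches:
    """Hypothetical helper that applies, suspends, and restores global patches.

    Each patch is a (target, attribute_name, new_value) triple.
    """

    def __init__(self, patches):
        self._patches = list(patches)
        self._saved = []

    def apply(self):
        # Record the target's own attribute (not an inherited one) before overwriting it.
        self._saved = [(obj, name, vars(obj).get(name, _MISSING)) for obj, name, _ in self._patches]
        for obj, name, value in self._patches:
            setattr(obj, name, value)

    def restore(self):
        # Undo in reverse order; delete attributes that did not exist before.
        for obj, name, old in reversed(self._saved):
            if old is _MISSING:
                delattr(obj, name)
            else:
                setattr(obj, name, old)
        self._saved = []

    @contextlib.contextmanager
    def active(self):
        """Apply the patches for the duration of tracing/compilation."""
        self.apply()
        try:
            yield
        finally:
            self.restore()

    @contextlib.contextmanager
    def suspended(self):
        """Temporarily undo the patches, e.g. while a Python callback runs."""
        self.restore()
        try:
            yield
        finally:
            self.apply()


class FakeTracer:
    """Stand-in for DynamicJaxprTracer: defining __eq__ makes instances unhashable."""

    def __eq__(self, other):
        return self is other


config = SimpleNamespace(enable_x64=False)  # stand-in for a global config flag
patches = GlobalPatches([
    # A lambda is needed here: the builtin id would not bind as a method.
    (FakeTracer, "__hash__", lambda self: id(self)),
    (config, "enable_x64", True),
])

tracer = FakeTracer()
with patches.active():
    assert hash(tracer) == id(tracer) and config.enable_x64
    with patches.suspended():  # inside a callback: the original state is visible
        assert not config.enable_x64
    assert config.enable_x64
assert not config.enable_x64
```

Saving whatever values were present before compilation (rather than resetting to hard-coded defaults) means a callback sees exactly the state the user had, and leaving the outermost scope leaves no trace of the patches.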

Note: rather than modifying JAX itself by adding a hash to jax.interpreters.partial_eval.DynamicJaxprTracer, could we change the PennyLane wire utilities to check whether a wire is a DynamicJaxprTracer and use its id as its hash?
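A rough sketch of that alternative, assuming wire bookkeeping goes through a helper that derives a hashable key for each wire. wire_key, wire_indices, TRACER_TYPES, and FakeTracer are hypothetical; PennyLane's actual wire utilities are not shown here, and a real version would check against the JAX tracer class instead of FakeTracer.

```python
from collections.abc import Hashable, Iterable
from typing import Any


class FakeTracer:
    """Stand-in for jax.interpreters.partial_eval.DynamicJaxprTracer (unhashable)."""

    def __eq__(self, other):
        return self is other


# In a real integration this tuple would hold the JAX tracer class instead.
TRACER_TYPES = (FakeTracer,)


def wire_key(wire: Any) -> Hashable:
    """Hypothetical helper: return a hashable key for a wire without patching JAX."""
    if isinstance(wire, TRACER_TYPES):
        # Tracers compare by identity, so id() is a valid key while the tracer is alive.
        # Tagging it avoids clashing with an integer wire label that equals the id.
        return ("tracer", id(wire))
    return wire


def wire_indices(wires: Iterable[Any]) -> dict:
    """Map each distinct wire (by wire_key) to its first-seen position."""
    index = {}
    for wire in wires:
        index.setdefault(wire_key(wire), len(index))
    return index


t = FakeTracer()
# Three distinct wires; the repeated tracer reuses index 1.
assert wire_indices([0, t, 1, t]) == {0: 0, ("tracer", id(t)): 1, 1: 2}
```

This keeps the identity-based hashing local to the wire handling code, so nothing about the tracer class changes globally and callbacks see an unmodified JAX.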

erick-xanadu added the bug, enhancement, and chore labels on Jul 9, 2024
@dime10
Contributor

dime10 commented Jul 9, 2024

What is the bug?

> Note, could we instead of changing jax.interpreters.partial_eval.DynamicJaxprTracer and adding a hash, can't we change pennylane wire utilities to find whether the wire is jax.interpreters.partial_eval.DynamicJaxprTracer and compute the id as its hash instead of modifying jax itself?

Yes, I think that's a good idea! The capture module will probably take care of this.

@erick-xanadu
Contributor Author

> What is the bug?

The bug is that we are changing global state when we shouldn't 😅 The concrete error is #894, but that is just a symptom.
