
grad does not work when using dynamic one-shot #1092

Open
mehrdad2m opened this issue Sep 3, 2024 · 3 comments

Comments

@mehrdad2m
Contributor

Issue description

Using grad does not work with the dynamic one-shot mid-circuit measurement method (mcm_method="one-shot").
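For context, the one-shot method roughly executes the circuit once per shot and aggregates the single-shot samples into the requested statistic. The toy NumPy model below illustrates that idea for the RX/PauliZ circuit used in this report. It is not Catalyst's implementation, and one_shot_expval_z is a hypothetical helper; the exact value it approximates is cos(x).

```python
import numpy as np


def one_shot_expval_z(x: float, shots: int, seed: int = 0) -> float:
    """Toy model of one-shot execution: run RX(x)|0> once per shot, average Z outcomes.

    Hypothetical helper for illustration only; not a Catalyst or PennyLane API.
    """
    rng = np.random.default_rng(seed)
    p0 = np.cos(x / 2) ** 2  # probability of measuring |0> after RX(x)
    outcomes = []
    for _ in range(shots):
        # Each "execution" yields a single sample: +1 for |0>, -1 for |1>
        outcomes.append(1.0 if rng.random() < p0 else -1.0)
    return float(np.mean(outcomes))


estimate = one_shot_expval_z(1.0, shots=10_000)
print(estimate, np.cos(1.0))
```

Because each shot is a separate execution, the result is a stochastic estimate whose spread shrinks as the number of shots grows.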

  • Actual behavior:
    The following code crashes:
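import pennylane as qml
from catalyst import grad, qjit

# Device assumed to match the one in the expected-behavior snippet below
dev = qml.device("lightning.qubit", wires=1, shots=5)

@qml.qnode(dev, diff_method="best", mcm_method="one-shot")
def f(x: float):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=0))

@qjit
def grad_f(x):
    return grad(f, method="auto")(x)

print(grad_f(1.0))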

which crashes with the following message:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
    print(grad_f(1.0))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 528, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 38, in grad_f
    return grad(f, method="auto")(x)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
    results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
    print(grad_f(1.0))
          ^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 532, in jit_compile
    self.mlir_module, self.mlir = self.generate_ir()
                                  ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 711, in _grad_lowering
    attr = ir.DenseElementsAttr.get(nparray, type=const_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
  • Expected behavior: the gradient is computed, as it is for the same circuit without mcm_method="one-shot":

    import pennylane as qml
    from catalyst import grad, qjit

    dev = qml.device('lightning.qubit', wires=1, shots=5)

    @qml.qnode(dev, diff_method="best")
    def g(x: float):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    @qjit
    def grad_g(x):
        return grad(g, method="auto")(x)

    print(grad_g(1.0))

    which prints

     -0.4
    
  • Reproduces how often: 100%

  • System information:

    Name: PennyLane
    Version: 0.38.0.dev24
    Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
    Home-page: https://github.com/PennyLaneAI/pennylane
    Author: 
    Author-email: 
    License: Apache License 2.0
    Location: /Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages
    Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
    Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, PennyLane_Lightning, PennyLane_Lightning_Kokkos
    
    Platform info:           macOS-14.6.1-arm64-arm-64bit
    Python version:          3.12.4
    Numpy version:           1.26.4
    Scipy version:           1.12.0
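For reference, after RX(x) on |0> the circuit computes <Z> = cos(x), so the exact derivative at x = 1.0 is -sin(1.0), about -0.841. Since the device uses shots=5, the printed -0.4 is presumably a finite-shot estimate of that value. The sketch below reproduces the exact value with the two-term parameter-shift rule on an analytic (infinite-shot) expectation; which method diff_method="best" actually selects here is an assumption, and both helper names are hypothetical.

```python
import numpy as np


def expval_z_after_rx(x: float) -> float:
    # Exact <Z> for RX(x)|0>: cos^2(x/2) - sin^2(x/2) = cos(x)
    return float(np.cos(x / 2) ** 2 - np.sin(x / 2) ** 2)


def parameter_shift(f, x: float, shift: float = np.pi / 2) -> float:
    # Two-term parameter-shift rule, exact for rotations generated by Pauli/2 (e.g. RX)
    return (f(x + shift) - f(x - shift)) / 2


grad_exact = parameter_shift(expval_z_after_rx, 1.0)
print(grad_exact, -np.sin(1.0))
```

With a finite number of shots, each shifted evaluation is itself a sampled estimate, which is why shot-based runs scatter around this exact value.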
    
@mehrdad2m
Contributor Author

mehrdad2m commented Sep 3, 2024

The same crash occurs with value_and_grad:

import pennylane as qml
from catalyst import qjit, value_and_grad

def workflow(x: float):
    @qml.qnode(qml.device("lightning.qubit", wires=3, shots=10), mcm_method="one-shot")
    def circuit1():
        qml.CNOT(wires=[0, 1])
        qml.RX(0, wires=[2])
        return qml.probs()  # This is [1, 0, 0, ...]

    return x * (circuit1()[0])

result2 = qjit(value_and_grad(workflow))(3.0)
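Since circuit1 leaves the register in |000>, workflow(x) reduces to x * 1.0, so the expected output of value_and_grad(workflow)(3.0) is a value of 3.0 and a derivative of 1.0. A plain NumPy cross-check of that expectation, using a hand-rolled statevector (wire 0 as the most significant bit, matching PennyLane's ordering) and a central finite difference; circuit1_probs and value_and_grad_fd are hypothetical helpers, not Catalyst APIs.

```python
import numpy as np

I2 = np.eye(2, dtype=complex)
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex)


def rx(theta: float) -> np.ndarray:
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=complex)


def circuit1_probs() -> np.ndarray:
    state = np.zeros(8, dtype=complex)
    state[0] = 1.0  # |000>
    state = np.kron(CNOT, I2) @ state  # CNOT on wires 0 and 1
    state = np.kron(np.kron(I2, I2), rx(0.0)) @ state  # RX(0) on wire 2
    return np.abs(state) ** 2


def workflow_np(x: float) -> float:
    # Classical post-processing from the report: x * probs[0]
    return float(x * circuit1_probs()[0])


def value_and_grad_fd(f, x: float, h: float = 1e-6):
    # Value plus central finite-difference derivative
    return f(x), (f(x + h) - f(x - h)) / (2 * h)


value, derivative = value_and_grad_fd(workflow_np, 3.0)
print(value, derivative)
```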

@mehrdad2m
Contributor Author

mehrdad2m commented Sep 3, 2024

The same crash happens with jvp and vjp:

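import pennylane as qml
from catalyst import jvp as C_jvp
from catalyst import qjit

# dev is not defined in the original snippet; a one-wire, shot-based device is assumed here
dev = qml.device("lightning.qubit", wires=1, shots=5)

x, t = (
    [-0.1, 0.5],
    [0.1, 0.33],
)

def circuit_rx(x1, x2):
    """A test quantum function"""
    qml.RX(x1, wires=0)
    qml.RX(x2, wires=0)
    return qml.expval(qml.PauliY(0))

@qjit
def C_workflow():
    f = qml.QNode(circuit_rx, device=dev, mcm_method="one-shot")
    return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))

r1 = C_workflow()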
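For reference, the two RX rotations compose to RX(x1 + x2), so the circuit computes <Y> = -sin(x1 + x2) and the expected JVP is -cos(x1 + x2) * (t1 + t2). A NumPy cross-check of that expectation, computing the gradient by central finite differences; fd_gradient is a hypothetical helper, and the cotangent of 1.0 in the vjp line is an illustrative choice.

```python
import numpy as np


def circuit_rx_np(x1: float, x2: float) -> float:
    # <Y> after RX(x1) RX(x2) on |0>; the rotations combine into RX(x1 + x2)
    theta = x1 + x2
    a, b = np.cos(theta / 2), -1j * np.sin(theta / 2)  # amplitudes of |0>, |1>
    return float(2 * np.imag(np.conj(a) * b))  # <Y> = 2 Im(conj(a) b) = -sin(theta)


def fd_gradient(f, x: np.ndarray, h: float = 1e-6) -> np.ndarray:
    # Central finite differences with respect to each positional argument
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(*(x + e)) - f(*(x - e))) / (2 * h)
    return g


x_np, t_np = np.array([-0.1, 0.5]), np.array([0.1, 0.33])
g = fd_gradient(circuit_rx_np, x_np)
jvp_value = g @ t_np  # forward mode: J @ t
vjp_value = 1.0 * g  # reverse mode with cotangent 1.0: ct * J
analytic_jvp = -np.cos(x_np.sum()) * t_np.sum()
print(jvp_value, analytic_jvp)
```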

The traceback differs slightly, but the underlying problem is the same:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
    @qjit
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
    self.aot_compile()
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 483, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 37, in C_workflow
    return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 488, in jvp
    results = jvp_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: unimplemented array format conversion from format: ?

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
    @qjit
     ^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
    self.aot_compile()
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 488, in aot_compile
    self.mlir_module, self.mlir = self.generate_ir()
                                  ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 821, in _jvp_lowering
    StableHLOConstantOp(ir.DenseElementsAttr.get(np.asarray(const))).results
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: unimplemented array format conversion from format: ?
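The "?" in this error is the Python buffer-protocol (struct module) format character for a C _Bool, which is what NumPy reports for boolean arrays. That suggests, though it is not confirmed here, that the constant reaching ir.DenseElementsAttr.get in _jvp_lowering is a boolean array. A quick way to inspect the buffer format of a value before it is handed to the MLIR bindings; buffer_format is a hypothetical helper.

```python
import numpy as np


def buffer_format(value) -> str:
    # Format string of the Python buffer exported by the NumPy array
    return memoryview(np.asarray(value)).format


print(buffer_format([True, False]))  # '?'
print(buffer_format([0.1, 0.33]))
```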

@mehrdad2m
Contributor Author

The following tests have been added to Catalyst in catalyst/frontend/test/pytest/test_mid_circuit_measurement.py and are currently marked xfail or skip. Any fix for this issue should make them pass:

test_mcm_method_with_grad
test_mcm_method_with_value_and_grad
test_mcm_method_with_jvp
test_mcm_method_with_vjp

Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant