Issue status: Closed
Description:
import jax.numpy as jnp
import pennylane as qml
from catalyst import *
# NOTE(review): the requested mcm_method="one-shot" applies to mid-circuit
# measurements, but this circuit body contains none.
@qml.qnode(dev, mcm_method="one-shot")
def circuit():
    """Apply alternating Hadamard/CNOT layers and measure <PauliY> on wire 0.

    The gate sequence is H(0), CNOT(0, 1), H(0), CNOT(0, 1), H(0).

    NOTE(review): ``dev`` is not defined in this snippet. It must be created
    before this decorator runs and must provide at least wires 0 and 1.

    Returns:
        The expectation value of PauliY on wire 0, as produced by ``qml.expval``.
    """
    # Two identical H + CNOT layers, unrolled at trace time.
    for _layer in range(2):
        qml.Hadamard(wires=0)
        qml.CNOT(wires=[0, 1])
    # Closing Hadamard before the measurement.
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliY(0))
@qjit
def mitigated_circuit():
    """Run ``circuit`` through Catalyst's ``mitigate_with_zne`` inside a qjit program.

    Scale factors 1 and 2 are passed to ``mitigate_with_zne``, and the
    resulting mitigated callable is invoked with no arguments.

    Returns:
        Whatever the ``mitigate_with_zne``-wrapped ``circuit`` returns.

    NOTE(review): the reported failure (``KeyError: 'call_jaxpr'`` raised from
    Catalyst's ``_zne_lowering``) occurs inside Catalyst's lowering code. It is
    not something this function can fix. Confirm against the Catalyst/JAX
    versions in use.
    """
    # Only ``jax.numpy`` is imported, under the alias ``jnp``. The bare name
    # ``jax`` is not bound by ``import jax.numpy as jnp``, and
    # ``from catalyst import *`` is not guaranteed to export it.
    scale_factors = jnp.array([1, 2])
    return mitigate_with_zne(circuit, scale_factors=scale_factors)()
The example above fails with the following error (traceback truncated):
File "catalyst/frontend/catalyst/jax_primitives.py", line 739, in _zne_lowering
_func_lowering(ctx, *args, call_jaxpr=jaxpr.eqns[0].params["call_jaxpr"], fn=fn, call=False)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'call_jaxpr'