Skip to content

Dynamic one shot does not work with ZNE #929

Closed
@dime10

Description

@dime10
import jax.numpy as jnp
import pennylane as qml
from catalyst import *

# NOTE(review): `dev` is not defined anywhere in the visible snippet; a device
# must exist before this decorator is evaluated -- confirm against the
# original report.
@qml.qnode(dev, mcm_method="one-shot")
def circuit():
    """Apply H(0), CNOT(0, 1), H(0), CNOT(0, 1), H(0) and measure PauliY on wire 0.

    Returns:
        The result of ``qml.expval(qml.PauliY(0))``.
    """
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliY(0))


@qjit
def mitigated_circuit():
    """Wrap ``circuit`` with ``mitigate_with_zne`` using scale factors 1 and 2.

    Returns:
        Whatever the callable produced by ``mitigate_with_zne`` returns when
        it is invoked with no arguments.
    """
    # Use the `jnp` alias: `import jax.numpy as jnp` binds only `jnp`, so the
    # bare name `jax` is not available in this snippet.
    s = jnp.array([1, 2])
    return mitigate_with_zne(circuit, scale_factors=s)()

Error output (excerpt):

  File "catalyst/frontend/catalyst/jax_primitives.py", line 739, in _zne_lowering
    _func_lowering(ctx, *args, call_jaxpr=jaxpr.eqns[0].params["call_jaxpr"], fn=fn, call=False)
                                          ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'call_jaxpr'

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions