
qjit fails when multiple tapes share the same loop object #1012

Open
paul0403 opened this issue Aug 13, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@paul0403
Contributor

paul0403 commented Aug 13, 2024

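```python
from typing import Callable, Sequence

import catalyst
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)

def my_quantum_transform(tape: qml.tape.QuantumTape) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
    # tape1 is the original tape; tape2 is a new tape built from the *same*
    # operation objects, so both tapes share the loop0 for_loop operation.
    tape1 = tape
    tape2 = qml.tape.QuantumTape(tape.operations, tape.measurements)

    def post_processing_fn(results):
        # Combine the results of the two tapes into a single value.
        return results[0] + results[1]

    return [tape1, tape2], post_processing_fn

dispatched_transform = qml.transform(my_quantum_transform)

@qml.qnode(dev)
def circuit():
    @catalyst.for_loop(0, 1, 1)
    def loop0(_, yy):
        qml.RX(3.14, wires=0)
        return yy + 2

    loop0(0)
    return qml.expval(qml.X(0))

circuit = dispatched_transform(circuit)
circuit = qjit(circuit)  # fails here, during compilation
print("qjit results: ", circuit())
```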

Output (the error is raised while `qjit` compiles the transformed QNode):
Traceback (most recent call last):
  File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
    circuit = qjit(circuit)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
    self.aot_compile()
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 481, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 606, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 531, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 604, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 584, in closure
    return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 165, in __call__
    res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 143, in _eval_quantum
    closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 1162, in trace_quantum_function
    qrp_out = trace_quantum_operations(tape, device, qreg_in, ctx, trace, mcm_config)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 655, in trace_quantum_operations
    qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/api_extensions/control_flow.py", line 1209, in trace_quantum
    op.bind_overwrite_classical_tracers(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 460, in bind_overwrite_classical_tracers
    out_quantum_tracer = self.binder(*in_expanded_tracers, **kwargs)[-1]
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 969, in bind
    source_info = jax_current()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
    e:i64[] = add b 2
    f:AbstractQbit() = qextract c 0
    g:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] f 3.14
    _:AbstractQreg() = qinsert c 0 g
    h:AbstractQbit() = qextract d 0
    i:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] h 3.14
    j:AbstractQreg() = qinsert d 0 i
  in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
    circuit = qjit(circuit)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
    self.aot_compile()
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 486, in aot_compile
    self.mlir_module, self.mlir = self.generate_ir()
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 621, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 558, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 608, in _func_lowering
    func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr, name_stack=ctx.name_stack)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 566, in _func_def_lowering
    func_op = mlir.lower_jaxpr_to_fun(ctx, fn.__name__, call_jaxpr, tuple(), name_stack=name_stack)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 2022, in _for_loop_lowering
    out, _ = mlir.jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1574, in jaxpr_subcomp
    assert len(args) == len(jaxpr.invars), (jaxpr, args)
AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
    e:i64[] = add b 2
    f:AbstractQbit() = qextract c 0
    g:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] f 3.14
    _:AbstractQreg() = qinsert c 0 g
    h:AbstractQbit() = qextract d 0
    i:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] h 3.14
    j:AbstractQreg() = qinsert d 0 i
  in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))
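The title points at the likely trigger: `tape2` is built from `tape.operations`, so both tapes appear to hold the same `for_loop` operation object, and the loop body in the assertion above ends up with instructions on both registers `c` and `d` (four invars, while only three argument groups are passed during lowering). The sketch below is a hypothetical diagnostic helper, not part of Catalyst or PennyLane, that flags operation objects appearing in more than one tape by identity; with real tapes one would pass something like `[t.operations for t in tapes]`. `ForLoopStandIn` is a made-up placeholder so the example runs without either library.

```python
from collections.abc import Sequence
from typing import Any


def find_shared_operations(tapes: Sequence[Sequence[Any]]) -> list[tuple[int, int, Any]]:
    """Return (first_tape, later_tape, op) for each op object found in more than one tape.

    Ops are compared by identity (``is``), not equality: the hazard is a single
    mutable object being traced once per tape.
    """
    first_seen: dict[int, int] = {}  # id(op) -> index of the first tape containing it
    shared = []
    for tape_idx, ops in enumerate(tapes):
        for op in ops:
            owner = first_seen.setdefault(id(op), tape_idx)
            if owner != tape_idx:
                shared.append((owner, tape_idx, op))
    return shared


class ForLoopStandIn:
    """Hypothetical stand-in for a loop operation; not a real Catalyst class."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"ForLoopStandIn({self.name!r})"


if __name__ == "__main__":
    loop = ForLoopStandIn("loop0")
    tape1_ops = [loop]
    tape2_ops = list(tape1_ops)  # new list, same op objects, as in the reproducer
    print(find_shared_operations([tape1_ops, tape2_ops]))
```

Copying the list of operations (as `qml.tape.QuantumTape(tape.operations, ...)` does) creates a new container but not new operation objects, which is why identity rather than equality is the useful check here.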
paul0403 added the bug (Something isn't working) label Aug 13, 2024
@josh146
Member

josh146 commented Aug 18, 2024

Thanks @paul0403! I just wanted to check with @erick-xanadu, is this expected with our current support for quantum transforms?

If I recall correctly, we integrated PennyLane (PL) tape transforms assuming that the QNode being transformed was a straight-line program (i.e., that no for loops or conditionals were present).

@dime10
Collaborator

dime10 commented Aug 19, 2024

> Thanks @paul0403! I just wanted to check with @erick-xanadu, is this expected with our current support for quantum transforms?
>
> If I recall correctly, we integrated PL tape transforms assuming that the QNode being transformed was a straight line program only (e.g., that no for loops or conditionals were present).

We already disallow mid-circuit measurements (MCMs) with transforms that produce multiple tapes, but I don't think that restriction goes far enough. Disallowing all hybrid ops (e.g., `for_loop` and `cond`) for the moment should at least prevent this issue.
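A minimal sketch of the extra validation suggested above, assuming the transform's output can be inspected as a list of per-tape operation lists before tracing begins. `HybridOp`, `ForLoopOp`, `RXOp`, and `validate_multi_tape_output` are hypothetical stand-ins, not Catalyst's actual classes or entry points, and raising `ValueError` with this message is likewise an assumption.

```python
from collections.abc import Sequence
from typing import Any


class HybridOp:
    """Hypothetical marker base class for control-flow ops (for_loop, while_loop, cond)."""


class ForLoopOp(HybridOp):
    """Hypothetical loop op; stands in for a real control-flow operation."""


class RXOp:
    """Hypothetical plain gate; straight-line programs contain only ops like this."""


def validate_multi_tape_output(tapes: Sequence[Sequence[Any]]) -> None:
    """Reject a transform that returns several tapes if any tape contains a hybrid op."""
    if len(tapes) <= 1:
        return  # single-tape transforms are not affected by this restriction
    for idx, ops in enumerate(tapes):
        for op in ops:
            if isinstance(op, HybridOp):
                raise ValueError(
                    f"Tape {idx} contains {type(op).__name__}; transforms that produce "
                    "multiple tapes currently require straight-line programs"
                )


if __name__ == "__main__":
    validate_multi_tape_output([[ForLoopOp()]])          # one tape: allowed
    validate_multi_tape_output([[RXOp()], [RXOp()]])     # straight-line: allowed
    try:
        validate_multi_tape_output([[ForLoopOp()], [ForLoopOp()]])
    except ValueError as err:
        print(err)
```

Checking before tracing means the user gets a clear error message instead of the lowering-time `AssertionError` in the report.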

@josh146
Member

josh146 commented Aug 19, 2024

So the fix here would simply be additional validation? Sounds good!

@erick-xanadu
Contributor

@josh146 yes, the assumption is that the QNode being transformed is a straight-line program. I think there was some initial validation, but perhaps it did not cover all possible cases.
