diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad13..2f97fec69 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -475,6 +475,10 @@ jobs: python3 -m pip install oqc-qcaas-client make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -558,6 +562,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -620,6 +628,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/frontend/catalyst/from_plxpr/__init__.py b/frontend/catalyst/from_plxpr/__init__.py index a39039fe1..3599025f6 100644 --- a/frontend/catalyst/from_plxpr/__init__.py +++ b/frontend/catalyst/from_plxpr/__init__.py @@ -15,4 +15,4 @@ """Conversion from plxpr to catalyst jaxpr""" from catalyst.from_plxpr.control_flow import handle_cond, handle_for_loop, handle_while_loop -from catalyst.from_plxpr.from_plxpr import from_plxpr, register_transform, trace_from_pennylane +from catalyst.from_plxpr.from_plxpr import from_plxpr, trace_from_pennylane diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 64d74400b..16997f3cf 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2022-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,502 +11,660 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ -This submodule defines a utility for converting plxpr into Catalyst jaxpr. +This module contains a patch for the upstream qml.QNode behaviour, in particular around +what happens when a QNode object is called during tracing. Mostly this involves bypassing +the default behaviour and replacing it with a function-like "QNode" primitive. """ -# pylint: disable=protected-access - +import logging +from copy import copy +from dataclasses import dataclass, replace +from typing import Callable, Sequence -import warnings -from functools import partial -from typing import Callable - -import jax +import jax.numpy as jnp import pennylane as qml -from jax.extend.core import ClosedJaxpr, Jaxpr -from jax.extend.linear_util import wrap_init -from pennylane.capture import PlxprInterpreter, qnode_prim -from pennylane.capture.expand_transforms import ExpandTransformsInterpreter -from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires -from pennylane.transforms import cancel_inverses as pl_cancel_inverses -from pennylane.transforms import commute_controlled as pl_commute_controlled -from pennylane.transforms import decompose as pl_decompose -from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding -from pennylane.transforms import merge_rotations as pl_merge_rotations -from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion -from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot - -from catalyst.device import extract_backend_info -from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter -from catalyst.jax_extras import make_jaxpr2, transient_jax_config -from catalyst.jax_primitives import ( - device_init_p, - device_release_p, - qalloc_p, - qdealloc_p, - quantum_kernel_p, +from jax.core import eval_jaxpr +from jax.tree_util import tree_flatten, tree_unflatten +from pennylane import exceptions +from pennylane.measurements import CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP +from pennylane.transforms.dynamic_one_shot import ( + gather_non_mcm, + init_auxiliary_tape, + parse_native_mid_circuit_measurements, ) -from catalyst.passes.pass_api import Pass -from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter -from .qubit_handler import ( - QubitHandler, - QubitIndexRecorder, -) +import catalyst +from catalyst.api_extensions import MidCircuitMeasure +from catalyst.device import QJITDevice +from catalyst.device.qjit_device import is_dynamic_wires +from catalyst.jax_extras import deduce_avals, get_implicit_and_explicit_flat_args, unzip2 +from catalyst.jax_extras.tracing import uses_transform +from catalyst.jax_primitives import quantum_kernel_p +from catalyst.jax_tracer import Function, trace_quantum_function +from catalyst.logging import debug_logger +from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass +from catalyst.tracing.contexts import EvaluationContext +from catalyst.tracing.type_signatures import filter_static_args +from catalyst.utils.exceptions import CompileError + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +@dataclass +class OutputContext: + """Context containing parameters needed for finalizing quantum function output.""" + + cpy_tape: any + classical_values: any + classical_return_indices: any + out_tree_expected: any + snapshots: any + shot_vector: any + num_mcm: int + + +def _resolve_mcm_config(mcm_config, shots): + """Helper function for resolving and validating that the mcm_config is valid for executing.""" + updated_values = {} + + updated_values["postselect_mode"] = ( + None if isinstance(shots, int) and shots == 0 else mcm_config.postselect_mode + ) + if mcm_config.mcm_method is None: + updated_values["mcm_method"] = "one-shot" + if mcm_config.mcm_method == "deferred": + raise ValueError("mcm_method='deferred' is not supported with Catalyst.") + if ( + mcm_config.mcm_method == "single-branch-statistics" + and mcm_config.postselect_mode == "hw-like" + ): + raise ValueError( + "Cannot use postselect_mode='hw-like' with Catalyst when mcm_method != 'one-shot'." + ) + if mcm_config.mcm_method == "one-shot" and shots == 0: + raise ValueError( + "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." + ) + + return replace(mcm_config, **updated_values) + +def _get_total_shots(qnode): + """ + Extract total shots from qnode. + If shots is None on the qnode, this method returns 0 (static). + This method allows the qnode shots to be either static (python int + literals) or dynamic (tracers). + """ + # due to possibility of tracer, we cannot use a simple `or` here to simplify + shots_value = qnode._shots.total_shots # pylint: disable=protected-access + if shots_value is None: + shots = 0 + else: + shots = shots_value + return shots + + +def _is_one_shot_compatible_device(qnode): + device_name = qnode.device.name + exclude_devices = {"softwareq.qpp", "nvidia.custatevec", "nvidia.cutensornet"} + + # Check device name against exclude list + if device_name in exclude_devices: + return False + + # Additional check for OQDDevice class + device_class_name = qnode.device.__class__.__name__ + return device_class_name != "OQDDevice" + + +def configure_mcm_and_try_one_shot(qnode, args, kwargs): + """Configure mid-circuit measurement settings and handle one-shot execution.""" + dynamic_one_shot_called = getattr(qnode, "_dynamic_one_shot_called", False) + if not dynamic_one_shot_called: + mcm_config = copy( + qml.devices.MCMConfig( + postselect_mode=qnode.execute_kwargs["postselect_mode"], + mcm_method=qnode.execute_kwargs["mcm_method"], + ) + ) + total_shots = _get_total_shots(qnode) + user_specified_mcm_method = mcm_config.mcm_method + mcm_config = _resolve_mcm_config(mcm_config, total_shots) + + # Check if measurements_from_{samples/counts} is being used + uses_measurements_from_samples = uses_transform(qnode, "measurements_from_samples") + uses_measurements_from_counts = uses_transform(qnode, "measurements_from_counts") + has_finite_shots = isinstance(total_shots, int) and total_shots > 0 + + # For cases that user are not tend to executed with one-shot, and facing + # 1. non-one-shot compatible device, + # 2. non-finite shots, + # 3. measurement transform, + # fallback to single-branch-statistics + one_shot_compatible = _is_one_shot_compatible_device(qnode) + one_shot_compatible &= has_finite_shots + one_shot_compatible &= not uses_measurements_from_samples + one_shot_compatible &= not uses_measurements_from_counts + + should_fallback = ( + not one_shot_compatible + and user_specified_mcm_method is None + and mcm_config.mcm_method == "one-shot" + ) -def _get_device_kwargs(device) -> dict: - """Calulcate the params for a device equation.""" - info = extract_backend_info(device) - # Note that the value of rtd_kwargs is a string version of - # the info kwargs, not the info kwargs itself - # this is due to ease of serialization to MLIR - return { - "rtd_kwargs": str(info.kwargs), - "rtd_lib": info.lpath, - "rtd_name": info.c_interface_name, - } + if should_fallback: + mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") + if mcm_config.mcm_method == "one-shot": + # If measurements_from_samples/counts while one-shot is used, raise an error + if uses_measurements_from_samples: + raise CompileError("measurements_from_samples is not supported with one-shot") + if uses_measurements_from_counts: + raise CompileError("measurements_from_counts is not supported with one-shot") -# code example has long lines -# pylint: disable=line-too-long -def from_plxpr(plxpr: ClosedJaxpr) -> Callable[..., Jaxpr]: - """Convert PennyLane variant jaxpr to Catalyst variant jaxpr. + mcm_config = replace( + mcm_config, postselect_mode=mcm_config.postselect_mode or "hw-like" + ) - Args: - jaxpr (ClosedJaxpr): PennyLane variant jaxpr + try: + return Function(dynamic_one_shot(qnode, mcm_config=mcm_config))(*args, **kwargs) + except (TypeError, ValueError, CompileError, NotImplementedError) as e: + # If user specified mcm_method, we can't fallback to single-branch-statistics, + # reraise the original error + if user_specified_mcm_method is not None: + raise + + # Fallback only if mcm was auto-determined + error_msg = str(e) + unsupported_measurement_error = any( + pattern in error_msg + for pattern in [ + "Native mid-circuit measurement mode does not support", + "qml.var(obs) cannot be returned when `mcm_method='one-shot'`", + "empty wires is not supported with dynamic wires in one-shot mode", + "No need to run one-shot mode", + ] + ) + + # Fallback if error is related to unsupported measurements + if unsupported_measurement_error: + logger.warning("Fallback to single-branch-statistics: %s", e) + mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") + else: + raise + return None + +def _reconstruct_output_with_classical_values( + measurement_results, classical_values, classical_return_indices +): + """ + Reconstruct the output values from the classical values and measurement results. + Args: + out: Output from measurement processing + classical_values: Classical values + classical_return_indices: Indices of classical values Returns: - Callable: A function that accepts the same arguments as the plxpr and returns catalyst - variant jaxpr. + results: Reconstructed output with classical values inserted + """ + if not classical_values: + return measurement_results - Note that the input jaxpr should be workflow level and contain qnode primitives, rather than - qfunc level with individual operators. + total_expected = len(classical_values) + len(measurement_results) + classical_iter = iter(classical_values) + measurement_iter = iter(measurement_results) - .. code-block:: python + def get_next_value(idx): + return next(classical_iter) if idx in classical_return_indices else next(measurement_iter) + + results = [get_next_value(i) for i in range(total_expected)] + return results - from catalyst.from_plxpr import from_plxpr - - qml.capture.enable() - - @qml.qnode(qml.device('lightning.qubit', wires=2)) - def circuit(x): - qml.RX(x, 0) - return qml.probs(wires=(0, 1)) - - def f(x): - return circuit(2 * x) ** 2 - - plxpr = jax.make_jaxpr(circuit)(0.5) - - print(from_plxpr(plxpr)(0.5)) - - .. code-block:: none - - { lambda ; a:f64[]. let - b:f64[4] = func[ - call_jaxpr={ lambda ; c:f64[]. let - device_init[ - rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None} - rtd_lib=*** - rtd_name=LightningSimulator - ] - d:AbstractQreg() = qalloc 2 - e:AbstractQbit() = qextract d 0 - f:AbstractQbit() = qinst[ - adjoint=False - ctrl_len=0 - op=RX - params_len=1 - qubits_len=1 - ] e c - g:AbstractQbit() = qextract d 1 - h:AbstractObs(num_qubits=2,primitive=compbasis) = compbasis f g - i:f64[4] = probs[shape=(4,) shots=None] h - j:AbstractQreg() = qinsert d 0 f - qdealloc j - in (i,) } - qnode= - ] a - in (b,) } +def _extract_classical_and_measurement_results(results, classical_return_indices): + """ + Split results into classical values and measurement results. + It assume that the results are in the order of classical values and measurement results. """ - return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) + num_classical_return_indices = len(classical_return_indices) + classical_values = results[:num_classical_return_indices] + measurement_results = results[num_classical_return_indices:] + return classical_values, measurement_results -class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" +def _finalize_output(out, ctx: OutputContext): + """ + Finalize the output by reconstructing with classical values and unflattening to the + expected tree structure. + Args: + out: The output to finalize + context: OutputContext containing all necessary parameters for finalization + """ + # Handle case with no measurements + if len(ctx.cpy_tape.measurements) == 0: + out = out[: -ctx.num_mcm] - def __init__(self): - self._pass_pipeline = [] - self.init_qreg = None + out = _reconstruct_output_with_classical_values( + out, ctx.classical_values, ctx.classical_return_indices + ) - # Compiler options for the new decomposition system - self.requires_decompose_lowering = False - self.decompose_tkwargs = {} # target gateset + out_tree_expected = ctx.out_tree_expected + if ctx.snapshots is not None: + out = (out[0], tree_unflatten(out_tree_expected[1], out[1])) + else: + out = tree_unflatten(out_tree_expected[0], out) + return out - super().__init__() +class QFunc: + """A device specific quantum function. -# pylint: disable=unused-argument, too-many-arguments -@WorkflowInterpreter.register_primitive(qnode_prim) -def handle_qnode( - self, *args, qnode, device, shots_len, execution_config, qfunc_jaxpr, n_consts, batch_dims=None -): - """Handle the conversion from plxpr to Catalyst jaxpr for the qnode primitive""" + Args: + qfunc (Callable): the quantum function + shots (int): How many times the circuit should be evaluated (or sampled) to estimate + the expectation values + device (a derived class from QubitDevice): a device specification which determines + the valid gate set for the quantum function + """ - self.qubit_index_recorder = QubitIndexRecorder() + def __new__(cls): + raise NotImplementedError() # pragma: no-cover - if shots_len > 1: - raise NotImplementedError("shot vectors are not yet supported for catalyst conversion.") + # pylint: disable=no-member + # pylint: disable=self-cls-assignment + @debug_logger + def __call__(self, *args, **kwargs): - shots = args[0] if shots_len else 0 - consts = args[shots_len : n_consts + shots_len] - non_const_args = args[shots_len + n_consts :] + if EvaluationContext.is_quantum_tracing(): + raise CompileError("Can't nest qnodes under qjit") - closed_jaxpr = ( - ClosedJaxpr(qfunc_jaxpr, consts) - if not self.requires_decompose_lowering - else _apply_compiler_decompose_to_plxpr( - inner_jaxpr=qfunc_jaxpr, - consts=consts, - ncargs=non_const_args, - tgateset=list(self.decompose_tkwargs.get("gate_set", [])), - ) - ) + assert isinstance(self, qml.QNode) - graph_succeeded = False - if self.requires_decompose_lowering: - closed_jaxpr, graph_succeeded = _collect_and_compile_graph_solutions( - inner_jaxpr=closed_jaxpr.jaxpr, - consts=closed_jaxpr.consts, - tkwargs=self.decompose_tkwargs, - ncargs=non_const_args, - ) + new_transform_program, new_pipeline = _extract_passes(self.transform_program) - # Fallback to the legacy decomposition if the graph-based decomposition failed - if not graph_succeeded: - # Remove the decompose-lowering pass from the pipeline - self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"] - closed_jaxpr = _apply_compiler_decompose_to_plxpr( - inner_jaxpr=closed_jaxpr.jaxpr, - consts=closed_jaxpr.consts, - ncargs=non_const_args, - tkwargs=self.decompose_tkwargs, - ) + # Update the qnode with peephole pipeline + pass_pipeline = kwargs.pop("pass_pipeline", ()) or () + pass_pipeline += new_pipeline + pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) + new_qnode = copy(self) + new_qnode._transform_program = new_transform_program # pylint: disable=protected-access + + # Mid-circuit measurement configuration/execution + fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) + + # If the qnode is failed to execute as one-shot, fn_result will be None + if fn_result is not None: + return fn_result + + new_device = copy(new_qnode.device) + qjit_device = QJITDevice(new_device) + + static_argnums = kwargs.pop("static_argnums", ()) + out_tree_expected = kwargs.pop("_out_tree_expected", []) + classical_return_indices = kwargs.pop("_classical_return_indices", []) + num_mcm_expected = kwargs.pop("_num_mcm_expected", []) + debug_info = kwargs.pop("debug_info", None) - def calling_convention(*args): - device_init_p.bind( - shots, - auto_qubit_management=(device.wires is None), - **_get_device_kwargs(device), + print(new_qnode.transform_program) + print(pass_pipeline) + + def _eval_quantum(*args, **kwargs): + trace_result = trace_quantum_function( + new_qnode.func, + qjit_device, + args, + kwargs, + new_qnode, + static_argnums, + debug_info, + ) + closed_jaxpr = trace_result.closed_jaxpr + out_type = trace_result.out_type + out_tree = trace_result.out_tree + out_tree_exp = trace_result.return_values_tree + cls_ret_idx = trace_result.classical_return_indices + num_mcm = trace_result.num_mcm + + out_tree_expected.append(out_tree_exp) + classical_return_indices.append(cls_ret_idx) + num_mcm_expected.append(num_mcm) + dynamic_args = filter_static_args(args, static_argnums) + args_expanded = get_implicit_and_explicit_flat_args(None, *dynamic_args, **kwargs) + res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) + _, out_keep = unzip2(out_type) + res_flat = [r for r, k in zip(res_expanded, out_keep) if k] + return tree_unflatten(out_tree, res_flat) + + flattened_fun, _, _, out_tree_promise = deduce_avals( + _eval_quantum, args, kwargs, static_argnums, debug_info ) - qreg = qalloc_p.bind(len(device.wires)) - self.init_qreg = QubitHandler(qreg, self.qubit_index_recorder) - converter = PLxPRToQuantumJaxprInterpreter( - device, shots, self.init_qreg, {}, self.qubit_index_recorder + dynamic_args = filter_static_args(args, static_argnums) + args_flat = tree_flatten((dynamic_args, kwargs))[0] + res_flat = quantum_kernel_p.bind( + flattened_fun, *args_flat, qnode=self, pipeline=tuple(pass_pipeline) ) - retvals = converter(closed_jaxpr, *args) - self.init_qreg.insert_all_dangling_qubits() - qdealloc_p.bind(self.init_qreg.get()) - device_release_p.bind() - return retvals - - if self.requires_decompose_lowering and graph_succeeded: - # Add gate_set attribute to the quantum kernel primitive - # decompose_gatesets is treated as a queue of gatesets to be used - # but we only support a single gateset for now in from_plxpr - # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition - # implementation. The current Python implementation cannot be mixed - # with other transforms in between. - gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])] - setattr(qnode, "decompose_gatesets", [gateset]) - - return quantum_kernel_p.bind( - wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), - *non_const_args, - qnode=qnode, - pipeline=self._pass_pipeline, + return tree_unflatten(out_tree_promise(), res_flat)[0] + + +# pylint: disable=protected-access +def _get_shot_vector(qnode): + shot_vector = qnode._shots.shot_vector if qnode._shots else [] + return ( + shot_vector + if len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) + else None ) -# The map below describes the parity between PL transforms and Catalyst passes. -# PL transforms having a Catalyst pass counterpart will have a name as value, -# otherwise their value will be None. The second value indicates if the transform -# requires decomposition to be supported by Catalyst. -transforms_to_passes = { - pl_cancel_inverses: ("remove-chained-self-inverse", False), - pl_commute_controlled: (None, False), - pl_decompose: (None, False), - pl_map_wires: (None, False), - pl_merge_amplitude_embedding: (None, True), - pl_merge_rotations: ("merge-rotations", False), - pl_single_qubit_fusion: (None, False), - pl_unitary_to_rot: (None, False), -} - - -# pylint: disable-next=redefined-outer-name -def register_transform(pl_transform, pass_name, decomposition): - """Register pennylane transforms and their conversion to Catalyst transforms""" - - # pylint: disable=too-many-arguments - @WorkflowInterpreter.register_primitive(pl_transform._primitive) - def handle_transform( - self, - *args, - args_slice, - consts_slice, - inner_jaxpr, - targs_slice, - tkwargs, - catalyst_pass_name=pass_name, - requires_decomposition=decomposition, - pl_plxpr_transform=pl_transform._plxpr_transform, - ): - """Handle the conversion from plxpr to Catalyst jaxpr for a - PL transform.""" - consts = args[consts_slice] - non_const_args = args[args_slice] - targs = args[targs_slice] - - # If the transform is a decomposition transform - # and the graph-based decomposition is enabled - if ( - hasattr(pl_plxpr_transform, "__name__") - and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" - and qml.decomposition.enabled_graph() - ): - if not self.requires_decompose_lowering: - self.requires_decompose_lowering = True +def _get_snapshot_results(mcm_method, tape, out): + """ + Get the snapshot results from the tape. + Args: + tape: The tape to get the snapshot results from. + out: The output of the tape. + Returns: + processed_snapshots: The extracted snapshot results if available; + otherwise, returns the original output. + measurement_results: The corresponding measurement results. + """ + # if no snapshot are present, return None, out + assert mcm_method == "one-shot" + + if not any(isinstance(op, qml.Snapshot) for op in tape.operations): + return None, out + + # Snapshots present: out[0] = snapshots, out[1] = measurements + snapshot_results, measurement_results = out + + # Take first shot for each snapshot + processed_snapshots = [ + snapshot[0] if hasattr(snapshot, "shape") and len(snapshot.shape) > 1 else snapshot + for snapshot in snapshot_results + ] + + return processed_snapshots, measurement_results + + +def _reshape_for_shot_vector(mcm_method, result, shot_vector): + assert mcm_method == "one-shot" + + # Calculate the shape for reshaping based on shot vector + result_list = [] + start_idx = 0 + for shot, copies in shot_vector: + # Reshape this segment to (copies, shot, n_wires) + segment = result[start_idx : start_idx + shot * copies] + if copies > 1: + segment_shape = (copies, shot, result.shape[-1]) + segment = jnp.reshape(segment, segment_shape) + result_list.extend([segment[i] for i in range(copies)]) + else: + result_list.append(segment) + start_idx += shot * copies + result = tuple(result_list) + return result + + +def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_vector): + """Process measurements when there are no mid-circuit measurements.""" + assert mcm_method == "one-shot" + + # flatten the outs structure + out, _ = tree_flatten(out) + new_out = [] + idx = 0 + + for m in cpy_tape.measurements: + if isinstance(m, CountsMP): + if isinstance(out[idx], tuple) and len(out[idx]) == 2: + # CountsMP result is stored as (keys, counts) tuple + keys, counts = out[idx] + idx += 1 else: - raise NotImplementedError( - "Multiple decomposition transforms are not yet supported." - ) + keys = out[idx] + counts = out[idx + 1] + idx += 2 + + if snapshots is not None: + counts_array = jnp.stack(counts, axis=0) + aggregated_counts = jnp.sum(counts_array, axis=0) + counts_result = (keys, aggregated_counts) + else: + aggregated_counts = jnp.sum(counts, axis=0) + counts_result = (keys[0], aggregated_counts) + + new_out.append(counts_result) + continue + + result = jnp.squeeze(out[idx]) + max_ndim = min(len(out[idx].shape), 2) + if out[idx].shape[0] == 1: + # Adding the first axis back when the first axis in the original + # array is 1, since it corresponds to the shot's dimension. + result = jnp.expand_dims(result, axis=0) + if result.ndim == 1 and max_ndim == 2: + result = jnp.expand_dims(result, axis=1) + + # Without MCMs and postselection, all samples are valid for use in MP computation. + is_valid = jnp.full((result.shape[0],), True) + processed_result = gather_non_mcm( + m, result, is_valid, postselect_mode="pad-invalid-samples" + ) - # Update the decompose_gateset to be used by the quantum kernel primitive - # TODO: we originally wanted to treat decompose_gateset as a queue of - # gatesets to be used by the decompose-lowering pass at MLIR - # but this requires a C++ implementation of the graph-based decomposition - # which doesn't exist yet. - self.decompose_tkwargs = tkwargs + # Handle shot vector reshaping for SampleMP + if isinstance(m, SampleMP) and shot_vector is not None: + processed_result = _reshape_for_shot_vector(mcm_method, processed_result, shot_vector) - # Note. We don't perform the compiler-specific decomposition here - # to be able to support multiple decomposition transforms - # and collect all the required gatesets - # as well as being able to support other transforms in between. + new_out.append(processed_result) + idx += 1 - # The compiler specific transformation will be performed - # in the qnode handler. + return (snapshots, tuple(new_out)) if snapshots else tuple(new_out) - # Add the decompose-lowering pass to the start of the pipeline - self._pass_pipeline.insert(0, Pass("decompose-lowering")) - # We still need to construct and solve the graph based on - # the current jaxpr based on the current gateset - # but we don't rewrite the jaxpr at this stage. +def _validate_one_shot_measurements( + mcm_config, tape: qml.tape.QuantumTape, user_specified_mcm_method, shot_vector, wires +) -> None: + """Validate measurements for one-shot mode. - # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) + Args: + mcm_config: The mid-circuit measurement configuration + tape: The quantum tape containing measurements to validate + qnode: The quantum node being transformed - # def gds_wrapper(*args): - # return gds_interpreter.eval(inner_jaxpr, consts, *args) + Raises: + TypeError: If unsupported measurement types are used + NotImplementedError: If measurement configuration is not supported + """ + mcm_method = mcm_config.mcm_method + assert mcm_method == "one-shot" + + # Check if using shot vector with non-SampleMP measurements + has_shot_vector = len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) + has_wires = wires is not None and not is_dynamic_wires(wires) + + # Raise an error if there are no mid-circuit measurements, it will fallback to + # single-branch-statistics + if ( + not any(isinstance(op, MidCircuitMeasure) for op in tape.operations) + and user_specified_mcm_method is None + ): + raise ValueError("No need to run one-shot mode when there are no mid-circuit measurements.") + + for m in tape.measurements: + # Check if measurement type is supported + if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)): + raise TypeError( + f"Native mid-circuit measurement mode does not support {type(m).__name__} " + "measurements." + ) - # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) - return self.eval(inner_jaxpr, consts, *non_const_args) + # Check variance with observable + if isinstance(m, VarianceMP) and m.obs: + raise TypeError( + "qml.var(obs) cannot be returned when `mcm_method='one-shot'` because " + "the Catalyst compiler does not support qml.sample(obs)." + ) - if catalyst_pass_name is None: - # Use PL's ExpandTransformsInterpreter to expand this and any embedded - # transform according to PL rules. It works by overriding the primitive - # registration, making all embedded transforms follow the PL rules - # from now on, hence ignoring the Catalyst pass conversion - def wrapper(*args): - return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + # Check if the measurement is supported with shot-vector + if has_shot_vector and not isinstance(m, SampleMP): + raise NotImplementedError( + f"Measurement {type(m).__name__} does not support shot-vectors. " + "Use qml.sample() instead." + ) - unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) - final_jaxpr = pl_plxpr_transform( - unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args + # Check dynamic wires with empty wires + if not has_wires and isinstance(m, (SampleMP, CountsMP)) and (m.wires.tolist() == []): + raise NotImplementedError( + f"Measurement {type(m).__name__} with empty wires is not supported with " + "dynamic wires in one-shot mode. Please specify a constant number of wires on " + "the device." ) - if requires_decomposition: - final_jaxpr = pl_decompose._plxpr_transform( - final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args - ) - return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) +# pylint: disable=protected-access,no-member,not-callable +def dynamic_one_shot(qnode, **kwargs): + """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. - # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) - return self.eval(inner_jaxpr, consts, *non_const_args) + Args: + qnode (QNode): a quantum circuit which will run ``num_shots`` times + Returns: + qnode (QNode): -# This is our registration factory for PL transforms. The loop below iterates -# across the map above and generates a custom handler for each transform. -# In order to ensure early binding, we pass the PL plxpr transform and the -# Catalyst pass as arguments whose default values are set by the loop. -for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): - register_transform(pl_transform, pass_name, decomposition) + The transformed circuit to be run ``num_shots`` times such as to simulate dynamic execution. -# pylint: disable=too-many-positional-arguments -def trace_from_pennylane( - fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info=None -): - """Capture the JAX program representation (JAXPR) of the wrapped function, using - PL capure module. + **Example** - Args: - fn(Callable): the user function to be traced - static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the - positions of static arguments. - dynamic_args(Seqence[Any]): the abstract values of the dynamic arguments. - abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): - An experimental option to specify dynamic tensor shapes. - This option affects the compilation of the annotated function. - Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors - with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section - below. - sig(Sequence[Any]): a tuple indicating the argument signature of the function. Static arguments - are indicated with their literal values, and dynamic arguments are indicated by abstract - values. - kwargs(Dict[str, Any]): keyword argumemts to the function. - debug_info(jax.api_util.debug_info): a source debug information object required by jaxprs. + Consider the following circuit: - Returns: - ClosedJaxpr: captured JAXPR - Tuple[Tuple[ShapedArray, bool]]: the return type of the captured JAXPR. - The boolean indicates whether each result is a value returned by the user function. - PyTreeDef: PyTree metadata of the function output - Tuple[Any]: the dynamic argument signature - """ + .. code-block:: python - with transient_jax_config({"jax_dynamic_shapes": True}): + dev = qml.device("lightning.qubit", shots=100) + params = np.pi / 4 * np.ones(2) - make_jaxpr_kwargs = { - "static_argnums": static_argnums, - "abstracted_axes": abstracted_axes, - "debug_info": debug_info, - } + @qjit + @dynamic_one_shot + @qml.qnode(dev, diff_method=None) + def circuit(x, y): + qml.RX(x, wires=0) + m0 = measure(0, reset=reset, postselect=postselect) - args = sig + @cond(m0 == 1) + def ansatz(): + qml.RY(y, wires=1) - if isinstance(fn, qml.QNode) and static_argnums: - # `make_jaxpr2` sees the qnode - # The static_argnum on the wrapped function takes precedence over the - # one in `make_jaxpr` - # https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231 - # Therefore we need to coordinate them manually - fn.static_argnums = static_argnums + ansatz() + return measure_f(wires=[0, 1]) - plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) - jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs) + The ``dynamic_one_shot`` decorator prompts the QNode to perform a hundred one-shot + calculations, where in each calculation the ``measure`` operations dynamically + measures the 0-wire and collapse the state vector stochastically. + """ - return jaxpr, out_type, out_treedef, sig + cpy_tape = None + mcm_config = kwargs.pop("mcm_config", None) + def transform_to_single_shot(qnode): + if not qnode._shots: + raise exceptions.QuantumFunctionError( + "dynamic_one_shot is only supported with finite shots." + ) -def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, ncargs, tgateset=None, tkwargs=None): - """Apply the compiler-specific decomposition for a given JAXPR. + user_specified_mcm_method = qnode.execute_kwargs["mcm_method"] + shot_vector = qnode._shots.shot_vector if qnode._shots else [] + wires = qnode.device.wires - This function first disables the graph-based decomposition optimization - to ensure that only high-level gates and templates with a single decomposition - are decomposed. It then performs the pre-mlir decomposition using PennyLane's - `plxpr_transform` function. + @qml.transform + def dynamic_one_shot_partial( + tape: qml.tape.QuantumTape, + ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: + nonlocal cpy_tape + cpy_tape = tape - `tgateset` is a list of target gateset for decomposition. - If provided, it will be combined with the default compiler ops for decomposition. - If not provided, `tkwargs` will be used as the keyword arguments for the - decomposition transform. This is to ensure compatibility with the existing - PennyLane decomposition transform as well as providing a fallback mechanism. + _validate_one_shot_measurements( + mcm_config, tape, user_specified_mcm_method, shot_vector, wires + ) - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - ncargs (list): Non-constant arguments for the JAXPR. - tgateset (list): A list of target gateset for decomposition. Defaults to None. - tkwargs (list): The keyword arguments of the decompose transform. Defaults to None. + if tape.batch_size is not None: + raise ValueError("mcm_method='one-shot' is not compatible with broadcasting") - Returns: - ClosedJaxpr: The decomposed JAXPR. - """ + aux_tapes = [init_auxiliary_tape(tape)] - # Disable the graph decomposition optimization - - # Why? Because for the compiler-specific decomposition we want to - # only decompose higher-level gates and templates that only have - # a single decomposition, and not do any further optimization - # based on the graph solution. - # Besides, the graph-based decomposition is not supported - # yet in from_plxpr for most gates and templates. - # TODO: Enable the graph-based decomposition - qml.decomposition.disable_graph() - - kwargs = ( - {"gate_set": set(COMPILER_OPS_FOR_DECOMPOSITION.keys()).union(tgateset)} - if tgateset - else tkwargs - ) - final_jaxpr = qml.transforms.decompose.plxpr_transform(inner_jaxpr, consts, (), kwargs, *ncargs) + def processing_fn(results): + return results - qml.decomposition.enable_graph() + return aux_tapes, processing_fn - return final_jaxpr + return dynamic_one_shot_partial(qnode) + single_shot_qnode = transform_to_single_shot(qnode) + single_shot_qnode = qml.set_shots(single_shot_qnode, shots=1) + if mcm_config is not None: + single_shot_qnode.execute_kwargs["postselect_mode"] = mcm_config.postselect_mode + single_shot_qnode.execute_kwargs["mcm_method"] = mcm_config.mcm_method + single_shot_qnode._dynamic_one_shot_called = True + total_shots = _get_total_shots(qnode) -def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): - """Collect and compile graph solutions for a given JAXPR. + def one_shot_wrapper(*args, **kwargs): + def wrap_single_shot_qnode(*_): + return single_shot_qnode(*args, **kwargs) - This function uses the DecompRuleInterpreter to evaluate - the input JAXPR and obtain a new JAXPR that incorporates - the graph-based decomposition solutions. + arg_vmap = jnp.empty((total_shots,), dtype=float) + results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap) + if isinstance(results[0], tuple) and len(results) == 1: + results = results[0] + has_mcm = any(isinstance(op, MidCircuitMeasure) for op in cpy_tape.operations) - This function doesn't modify the underlying quantum function - but rather constructs a new JAXPR with decomposition rules. + classical_return_indices = kwargs.pop("_classical_return_indices", [[]])[0] + num_mcm = kwargs.pop("_num_mcm_expected", [0])[0] + out_tree_expected = kwargs.pop("_out_tree_expected", [[]]) - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - tkwargs (list): The keyword arguments of the decompose transform. - ncargs (list): Non-constant arguments for the JAXPR. + # Split results into classical and measurement parts + classical_values, results = _extract_classical_and_measurement_results( + results, classical_return_indices + ) - Returns: - ClosedJaxpr: The decomposed JAXPR. - bool: A flag indicating whether the graph-based decomposition was successful. - """ - gds_interpreter = DecompRuleInterpreter(**tkwargs) - - def gds_wrapper(*args): - return gds_interpreter.eval(inner_jaxpr, consts, *args) - - graph_succeeded = True - - with warnings.catch_warnings(record=True) as captured_warnings: - warnings.simplefilter("always", UserWarning) - final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) - - for w in captured_warnings: - warnings.showwarning(w.message, w.category, w.filename, w.lineno) - # TODO: use a custom warning class for this in PennyLane to remove this - # string matching and make it more robust. - if "The graph-based decomposition system is unable" in str(w.message): # pragma: no cover - graph_succeeded = False - warnings.warn( - "Falling back to the legacy decomposition system.", - UserWarning, + out = list(results) + + shot_vector = _get_shot_vector(qnode) + snapshots, out = _get_snapshot_results(mcm_config.mcm_method, cpy_tape, out) + + if has_mcm and len(cpy_tape.measurements) > 0: + out = parse_native_mid_circuit_measurements( + cpy_tape, results=results, postselect_mode="pad-invalid-samples" + ) + if len(cpy_tape.measurements) == 1: + out = (out,) + elif len(cpy_tape.measurements) > 0: + out = _process_terminal_measurements( + mcm_config.mcm_method, cpy_tape, out, snapshots, shot_vector ) - return final_jaxpr, graph_succeeded + ctx = OutputContext( + cpy_tape=cpy_tape, + classical_values=classical_values, + classical_return_indices=classical_return_indices, + out_tree_expected=out_tree_expected, + snapshots=snapshots, + shot_vector=shot_vector, + num_mcm=num_mcm, + ) + return _finalize_output(out, ctx) -def _get_operator_name(op): - """Get the name of a pennylane operator, handling wrapped operators. + return one_shot_wrapper - Note: Controlled and Adjoint ops aren't supported in `gate_set` - by PennyLane's DecompositionGraph; unit tests were added in PennyLane. - """ - if isinstance(op, str): - return op - # Return NoNameOp if the operator has no _primitive.name attribute. - # This is to avoid errors when we capture the program - # as we deal with such ops later in the decomposition graph. - return getattr(op._primitive, "name", "NoNameOp") +def _extract_passes(transform_program): + tape_transforms = [] + pass_pipeline = [] + for t in transform_program: + if t.pass_name: + pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) + else: + tape_transforms.append(t) + return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 1227b7f1b..ab52e9a5c 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -308,6 +308,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.ctx.module_context = self.old_module_context +def _lowered_options(kwargs): + lowered_options = {} + for option, value in kwargs.items(): + mlir_option = str(option).replace("_", "-") + lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value) + return lowered_options + def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline): """Generate a transform module embedded in the current module and schedule the transformations in pipeline""" @@ -350,11 +357,16 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin with ir.InsertionPoint(bb_named_sequence): target = bb_named_sequence.arguments[0] for _pass in pipeline: - options = _pass.get_options() + if isinstance(_pass, qml.transforms.core.TransformContainer): + options = _lowered_options(_pass.kwargs) + name = _pass.pass_name + else: + options = _pass.get_options() + name = _pass.name apply_registered_pass_op = ApplyRegisteredPassOp( result=transform_mod_type, target=target, - pass_name=_pass.name, + pass_name=name, options=options, dynamic_options={}, ) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index edac9890a..fab0d2e09 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -43,7 +43,7 @@ from catalyst.jax_primitives import quantum_kernel_p from catalyst.jax_tracer import Function, trace_quantum_function from catalyst.logging import debug_logger -from catalyst.passes.pass_api import dictionary_to_list_of_passes +from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import filter_static_args from catalyst.utils.exceptions import CompileError @@ -283,18 +283,22 @@ def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) + new_transform_program, new_pipeline = _extract_passes(self.transform_program) + # Update the qnode with peephole pipeline - pass_pipeline = kwargs.pop("pass_pipeline", []) + pass_pipeline = kwargs.pop("pass_pipeline", []) + new_pipeline pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) + new_qnode = copy(self) + new_qnode._transform_program = new_transform_program # pylint: disable=protected-access # Mid-circuit measurement configuration/execution - fn_result = configure_mcm_and_try_one_shot(self, args, kwargs) + fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) # If the qnode is failed to execute as one-shot, fn_result will be None if fn_result is not None: return fn_result - new_device = copy(self.device) + new_device = copy(new_qnode.device) qjit_device = QJITDevice(new_device) static_argnums = kwargs.pop("static_argnums", ()) @@ -305,11 +309,11 @@ def __call__(self, *args, **kwargs): def _eval_quantum(*args, **kwargs): trace_result = trace_quantum_function( - self.func, + new_qnode.func, qjit_device, args, kwargs, - self, + new_qnode, static_argnums, debug_info, ) @@ -649,3 +653,14 @@ def wrap_single_shot_qnode(*_): return _finalize_output(out, ctx) return one_shot_wrapper + + +def _extract_passes(transform_program): + tape_transforms = [] + pass_pipeline = [] + for t in transform_program: + if t.pass_name: + pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) + else: + tape_transforms.append(t) + return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file