From a82878fd6651465a67488bb2cfb821a8a30558e6 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 4 Aug 2025 22:59:12 -0700 Subject: [PATCH 1/7] Add a new transformer that performs random pauli insertion --- cirq-core/cirq/transformers/__init__.py | 3 + .../cirq/transformers/pauli_insertion.py | 120 ++++++++++++++++++ .../cirq/transformers/pauli_insertion_test.py | 60 +++++++++ 3 files changed, 183 insertions(+) create mode 100644 cirq-core/cirq/transformers/pauli_insertion.py create mode 100644 cirq-core/cirq/transformers/pauli_insertion_test.py diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index e3d1c9a0d35..ce718222d6c 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -160,3 +160,6 @@ from cirq.transformers.insertion_sort import ( insertion_sort_transformer as insertion_sort_transformer, ) + + +from cirq.transformers.pauli_insertion import PauliInsertionTransformer as PauliInsertionTransformer diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py new file mode 100644 index 00000000000..a7a833f7d70 --- /dev/null +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -0,0 +1,120 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A pauli insertion transformer.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + +import numpy as np + +from cirq import circuits, ops +from cirq.transformers import transformer_api + +_PAULIS = [ops.I, ops.X, ops.Y, ops.Z] + + +def _is_target( + op: ops.Operation, + target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation], +): + if inspect.isclass(target): + if issubclass(target, ops.Operation): + return isinstance(op, target) + if not hasattr(op, 'gate'): + return False + return isinstance(op.gate, target) + if isinstance(target, ops.Gate): + if not hasattr(op, 'gate') or op.gate is None: + return False + return op.gate == target + return op in target + + +@transformer_api.transformer +class PauliInsertionTransformer: + r"""Creates a pauli insertion transformer. + + A pauli insertion operation samples paulis from $\{I, X, Y, Z\}^2$ with the given + probabilities and adds it before the target 2Q gate/operation. This procedure is commonly + used in zero noise extrapolation (ZNE), see appendix D of https://arxiv.org/abs/2503.20870. + """ + + def __init__( + self, + target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation], + probabilities: np.ndarray | None = None, + ): + """Creates a pauli insertion transformer that samples 2Q paulis with the given probabilities. + + Args: + target: The target gate, gatefamily, gateset, or type (e.g. PauliSumExponential). + probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis. + The order of the paulis is IXYZ. If None, assume uniform distribution. + Returns: + A gauge transformer. + """ + if probabilities is None: + probabilities = np.ones((4, 4)) / 16 + probabilities = np.asarray(probabilities) + assert probabilities.shape == (4, 4) + assert np.isclose(probabilities.sum(), 1) + + self.target = target + self._flat_probs = probabilities.reshape(-1) + + def __call__( + self, + circuit: circuits.AbstractCircuit, + *, + rng_or_seed: np.random.Generator | int | None = None, + context: transformer_api.TransformerContext | None = None, + ): + context = ( + context + if isinstance(context, transformer_api.TransformerContext) + else transformer_api.TransformerContext() + ) + rng = ( + rng_or_seed + if isinstance(rng_or_seed, np.random.Generator) + else np.random.default_rng(rng_or_seed) + ) + + if context.deep: + raise ValueError(f"this transformer doesn't support deep {context=}") + + tags_to_ignore = frozenset(context.tags_to_ignore) + new_circuit: list[circuits.Moment] = [] + for moment in circuit: + if any(tag in tags_to_ignore for tag in moment.tags): + continue + new_moment = [] + for op in moment: + if any(tag in tags_to_ignore for tag in op.tags): + continue + if not _is_target(op, self.target): + continue + pair = np.unravel_index(rng.choice(16, p=self._flat_probs), (4, 4)) + for pauli_index, q in zip(pair, op.qubits): + if new_circuit and (q not in new_circuit[-1].qubits): + new_circuit[-1] += _PAULIS[pauli_index](q) + else: + new_moment.append(_PAULIS[pauli_index](q)) + if new_moment: + new_circuit.append(circuits.Moment(new_moment)) + new_circuit.append(moment) + return circuits.Circuit.from_moments(*new_circuit) diff --git a/cirq-core/cirq/transformers/pauli_insertion_test.py b/cirq-core/cirq/transformers/pauli_insertion_test.py new file mode 100644 index 00000000000..22d5832a5e7 --- /dev/null +++ b/cirq-core/cirq/transformers/pauli_insertion_test.py @@ -0,0 +1,60 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import cirq + +_PAULIS = [cirq.I, cirq.X, cirq.Y, cirq.Z] + + +def _random_probs(n: int, seed: int | None = None): + rng = np.random.default_rng(seed) + for _ in range(n): + probs = rng.random((4, 4)) + probs /= probs.sum() + yield probs + + +@pytest.mark.parametrize('probs', _random_probs(3, 0)) +def test_pauli_insertion_with_probabilities(probs): + c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs) + count = np.zeros((4, 4)) + for _ in range(100): + nc = transformer(c) + assert len(nc) == 2 + u, v = nc[0] + i = _PAULIS.index(u.gate) + j = _PAULIS.index(v.gate) + count[i, j] += 1 + count = count / count.sum() + np.testing.assert_allclose(count, probs, atol=0.1) + + +@pytest.mark.parametrize('probs', _random_probs(3, 0)) +def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs): + c = cirq.Circuit.from_moments([], [cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324]) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs) + count = np.zeros((4, 4)) + for _ in range(100): + nc = transformer(c) + assert len(nc) == 2 + u, v = nc[0] + i = _PAULIS.index(u.gate) + j = _PAULIS.index(v.gate) + count[i, j] += 1 + count = count / count.sum() + np.testing.assert_allclose(count, probs, atol=0.1) From 83be0e677ebc2b98b96d748c1bc27c9004d3e7c3 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 4 Aug 2025 23:02:37 -0700 Subject: [PATCH 2/7] lint --- cirq-core/cirq/transformers/pauli_insertion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index a7a833f7d70..9c51a3e9fee 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -17,7 +17,6 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING import numpy as np @@ -58,7 +57,7 @@ def __init__( target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation], probabilities: np.ndarray | None = None, ): - """Creates a pauli insertion transformer that samples 2Q paulis with the given probabilities. + """Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities. Args: target: The target gate, gatefamily, gateset, or type (e.g. PauliSumExponential). From 2c46d19e447c8c38610df3b373d370e0520b460e Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 4 Aug 2025 23:06:11 -0700 Subject: [PATCH 3/7] nit --- cirq-core/cirq/transformers/pauli_insertion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index 9c51a3e9fee..89e24e37ce5 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -100,6 +100,7 @@ def __call__( new_circuit: list[circuits.Moment] = [] for moment in circuit: if any(tag in tags_to_ignore for tag in moment.tags): + new_circuit.append(moment) continue new_moment = [] for op in moment: From 9c2064f927804e843984dc5d96961c01cd0f18b7 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Tue, 5 Aug 2025 08:45:46 -0700 Subject: [PATCH 4/7] mypy & coverage --- .../cirq/transformers/pauli_insertion.py | 33 ++++++----------- .../cirq/transformers/pauli_insertion_test.py | 37 +++++++++++++++++-- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index 89e24e37ce5..1c73beca022 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -23,24 +23,7 @@ from cirq import circuits, ops from cirq.transformers import transformer_api -_PAULIS = [ops.I, ops.X, ops.Y, ops.Z] - - -def _is_target( - op: ops.Operation, - target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation], -): - if inspect.isclass(target): - if issubclass(target, ops.Operation): - return isinstance(op, target) - if not hasattr(op, 'gate'): - return False - return isinstance(op.gate, target) - if isinstance(target, ops.Gate): - if not hasattr(op, 'gate') or op.gate is None: - return False - return op.gate == target - return op in target +_PAULIS: tuple[ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type] @transformer_api.transformer @@ -54,13 +37,13 @@ class PauliInsertionTransformer: def __init__( self, - target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation], + target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate], probabilities: np.ndarray | None = None, ): """Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities. Args: - target: The target gate, gatefamily, gateset, or type (e.g. PauliSumExponential). + target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte). probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis. The order of the paulis is IXYZ. If None, assume uniform distribution. Returns: @@ -72,7 +55,13 @@ def __init__( assert probabilities.shape == (4, 4) assert np.isclose(probabilities.sum(), 1) - self.target = target + if inspect.isclass(target): + self.target = ops.GateFamily(target) + elif isinstance(target, ops.Gate): + self.target = ops.Gateset(target) + else: + assert isinstance(target, (ops.Gateset, ops.GateFamily)) + self.target = target self._flat_probs = probabilities.reshape(-1) def __call__( @@ -106,7 +95,7 @@ def __call__( for op in moment: if any(tag in tags_to_ignore for tag in op.tags): continue - if not _is_target(op, self.target): + if op not in self.target: continue pair = np.unravel_index(rng.choice(16, p=self._flat_probs), (4, 4)) for pauli_index, q in zip(pair, op.qubits): diff --git a/cirq-core/cirq/transformers/pauli_insertion_test.py b/cirq-core/cirq/transformers/pauli_insertion_test.py index 22d5832a5e7..276305eb4db 100644 --- a/cirq-core/cirq/transformers/pauli_insertion_test.py +++ b/cirq-core/cirq/transformers/pauli_insertion_test.py @@ -29,12 +29,17 @@ def _random_probs(n: int, seed: int | None = None): @pytest.mark.parametrize('probs', _random_probs(3, 0)) -def test_pauli_insertion_with_probabilities(probs): +@pytest.mark.parametrize( + 'target', + [cirq.ZZPowGate, cirq.ZZ**0.324, cirq.Gateset(cirq.ZZ**0.324), cirq.GateFamily(cirq.ZZ**0.324)], +) +def test_pauli_insertion_with_probabilities(probs, target): c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324) - transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs) + transformer = cirq.transformers.PauliInsertionTransformer(target, probs) count = np.zeros((4, 4)) + rng = np.random.default_rng(0) for _ in range(100): - nc = transformer(c) + nc = transformer(c, rng_or_seed=rng) assert len(nc) == 2 u, v = nc[0] i = _PAULIS.index(u.gate) @@ -49,8 +54,9 @@ def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs): c = cirq.Circuit.from_moments([], [cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324]) transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs) count = np.zeros((4, 4)) + rng = np.random.default_rng(0) for _ in range(100): - nc = transformer(c) + nc = transformer(c, rng_or_seed=rng) assert len(nc) == 2 u, v = nc[0] i = _PAULIS.index(u.gate) @@ -58,3 +64,26 @@ def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs): count[i, j] += 1 count = count / count.sum() np.testing.assert_allclose(count, probs, atol=0.1) + + +def test_invalid_context_raises(): + c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + with pytest.raises(ValueError): + _ = transformer(c, context=cirq.TransformerContext(deep=True)) + + +def test_transformer_ignores_tagged_ops(): + op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324 + c = cirq.Circuit(op.with_tags('ignore')) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + + assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c + + +def test_transformer_ignores_tagged_moments(): + op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324 + c = cirq.Circuit(cirq.Moment(op).with_tags('ignore')) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + + assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c From 0b6a9a4517416361157ed0f4cb6b960aaba1ea88 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Thu, 7 Aug 2025 11:02:37 -0700 Subject: [PATCH 5/7] address comments --- .../cirq/transformers/pauli_insertion.py | 51 ++++++++++++++----- .../cirq/transformers/pauli_insertion_test.py | 15 ++++++ 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index 1c73beca022..63b06de4775 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -17,6 +17,7 @@ from __future__ import annotations import inspect +from collections.abc import Mapping import numpy as np @@ -38,22 +39,30 @@ class PauliInsertionTransformer: def __init__( self, target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate], - probabilities: np.ndarray | None = None, + probabilities: np.ndarray | Mapping[tuple[ops.Qid, ops.Qid], np.ndarray] | None = None, ): """Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities. Args: target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte). - probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis. - The order of the paulis is IXYZ. If None, assume uniform distribution. - Returns: - A gauge transformer. + probabilities: Optional ndarray or mapping[qubit-pair, nndarray] representing the + probabilities of sampling 2Q paulis. The order of the paulis is IXYZ. + If at operation `op` a pair (i, j) is sampled then _PAULIS[i] is applied + to op.qubits[0] and _PAULIS[j] is applied to op.qubits[1]. + If None, assume uniform distribution. """ if probabilities is None: probabilities = np.ones((4, 4)) / 16 - probabilities = np.asarray(probabilities) - assert probabilities.shape == (4, 4) - assert np.isclose(probabilities.sum(), 1) + elif isinstance(probabilities, dict): + probabilities = {k: np.asarray(v) for k, v in probabilities.items()} + for probs in probabilities.values(): + assert np.isclose(probs.sum(), 1) + assert probs.shape == (4, 4) + else: + probabilities = np.asarray(probabilities) + assert np.isclose(probabilities.sum(), 1) + assert probabilities.shape == (4, 4) + self.probabilities = probabilities if inspect.isclass(target): self.target = ops.GateFamily(target) @@ -62,7 +71,21 @@ def __init__( else: assert isinstance(target, (ops.Gateset, ops.GateFamily)) self.target = target - self._flat_probs = probabilities.reshape(-1) + + def _is_target(self, op: ops.Operation) -> bool: + if isinstance(self.probabilities, dict) and op.qubits not in self.probabilities: + return False + return op in self.target + + def _sample( + self, qubits: tuple[ops.Qid, ops.Qid], rng: np.random.Generator + ) -> tuple[ops.Gate, ops.Gate]: + if isinstance(self.probabilities, dict): + flat_probs = self.probabilities[qubits].reshape(-1) + else: + flat_probs = self.probabilities.reshape(-1) + i, j = np.unravel_index(rng.choice(16, p=flat_probs), (4, 4)) + return _PAULIS[i], _PAULIS[j] def __call__( self, @@ -95,14 +118,14 @@ def __call__( for op in moment: if any(tag in tags_to_ignore for tag in op.tags): continue - if op not in self.target: + if not self._is_target(op): continue - pair = np.unravel_index(rng.choice(16, p=self._flat_probs), (4, 4)) - for pauli_index, q in zip(pair, op.qubits): + pair = self._sample(op.qubits, rng) + for pauli, q in zip(pair, op.qubits): if new_circuit and (q not in new_circuit[-1].qubits): - new_circuit[-1] += _PAULIS[pauli_index](q) + new_circuit[-1] += pauli(q) else: - new_moment.append(_PAULIS[pauli_index](q)) + new_moment.append(pauli(q)) if new_moment: new_circuit.append(circuits.Moment(new_moment)) new_circuit.append(moment) diff --git a/cirq-core/cirq/transformers/pauli_insertion_test.py b/cirq-core/cirq/transformers/pauli_insertion_test.py index 276305eb4db..130c95eed9a 100644 --- a/cirq-core/cirq/transformers/pauli_insertion_test.py +++ b/cirq-core/cirq/transformers/pauli_insertion_test.py @@ -87,3 +87,18 @@ def test_transformer_ignores_tagged_moments(): transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c + + +def test_transformer_ignores_with_probs_map(): + qs = tuple(cirq.LineQubit.range(3)) + op = cirq.ZZ(*qs[:2]) ** 0.324 + c = cirq.Circuit(cirq.Moment(op)) + transformer = cirq.transformers.PauliInsertionTransformer( + cirq.ZZPowGate, {qs[1:]: np.ones((4, 4)) / 16} + ) + + assert transformer(c) == c # qubits are not in target + + c = cirq.Circuit(cirq.Moment(op.with_qubits(*qs[1:]))) + nc = transformer(c) + assert len(nc) == 2 From a46e59941be123761a7ec39ced97db608f157fc4 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Thu, 7 Aug 2025 11:07:40 -0700 Subject: [PATCH 6/7] fix types --- cirq-core/cirq/transformers/pauli_insertion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index 63b06de4775..bad62d4aeed 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -24,7 +24,7 @@ from cirq import circuits, ops from cirq.transformers import transformer_api -_PAULIS: tuple[ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type] +_PAULIS: tuple[ops.Gate, ops.Gate, ops.Gate, ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type] @transformer_api.transformer @@ -65,7 +65,7 @@ def __init__( self.probabilities = probabilities if inspect.isclass(target): - self.target = ops.GateFamily(target) + self.target: ops.GateFamily | ops.Gateset = ops.GateFamily(target) elif isinstance(target, ops.Gate): self.target = ops.Gateset(target) else: @@ -78,7 +78,7 @@ def _is_target(self, op: ops.Operation) -> bool: return op in self.target def _sample( - self, qubits: tuple[ops.Qid, ops.Qid], rng: np.random.Generator + self, qubits: tuple[ops.Qid, ...], rng: np.random.Generator ) -> tuple[ops.Gate, ops.Gate]: if isinstance(self.probabilities, dict): flat_probs = self.probabilities[qubits].reshape(-1) From 4eb0bddfb7d1686c41b0c4dcb7b32b7db1710518 Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Thu, 7 Aug 2025 12:18:44 -0700 Subject: [PATCH 7/7] nit --- cirq-core/cirq/transformers/pauli_insertion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py index bad62d4aeed..df8188dc127 100644 --- a/cirq-core/cirq/transformers/pauli_insertion.py +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -81,6 +81,7 @@ def _sample( self, qubits: tuple[ops.Qid, ...], rng: np.random.Generator ) -> tuple[ops.Gate, ops.Gate]: if isinstance(self.probabilities, dict): + assert len(qubits) == 2 flat_probs = self.probabilities[qubits].reshape(-1) else: flat_probs = self.probabilities.reshape(-1)