From 93986f9ca2d867749decff82053aed68b9ca88e6 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Thu, 17 Apr 2025 15:59:53 -0700 Subject: [PATCH 1/6] Improve controlled gate decomposition logic and add targeted unit tests --- cirq-core/cirq/ops/controlled_gate.py | 15 ++++++---- cirq-core/cirq/ops/controlled_gate_test.py | 33 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index d194e645816..3d186045121 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -177,18 +177,21 @@ def _decompose_with_context_( protocols.unitary(self.sub_gate), control_qubits, qubits[-1] ) return invert_ops + decomposed_ops + invert_ops - if isinstance(self.sub_gate, gp.GlobalPhaseGate): - # A controlled global phase is a diagonal gate, where each active control value index - # is set equal to the phase angle. + # Handle global phase gates by checking if the unitary is a 1x1 matrix with a global phase + unitary = protocols.unitary(self.sub_gate, default=None) + if ( + unitary is not None + and unitary.shape == (1, 1) + and np.isclose(np.abs(unitary[0, 0]), 1) + ): shape = self.control_qid_shape if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}: - # Could work in theory, but DiagonalGate decompose does not support them. return NotImplemented - angle = np.angle(complex(self.sub_gate.coefficient)) + angle = np.angle(unitary[0, 0]) rads = np.zeros(shape=shape) for hot in self.control_values.expand(): rads[hot] = angle - return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) if isinstance(self.sub_gate, common_gates.CZPowGate): z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) num_controls = self.num_controls() + 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..83734b616bc 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -737,3 +737,36 @@ def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) + + +@pytest.mark.parametrize( + "phase,num_controls,control_values", + [ + (0, 1, (1,)), + (np.pi / 4, 1, (1,)), + (np.pi / 2, 2, (1, 0)), + (np.pi, 2, (0, 1)), + (2 * np.pi, 3, (1, 1, 1)), + ] +) +def test_controlled_global_phase_gate_unitary(phase, num_controls, control_values): + coefficient = np.exp(1j * phase) + sub_gate = cirq.GlobalPhaseGate(coefficient=coefficient) + controlled_gate = cirq.ControlledGate( + sub_gate=sub_gate, num_controls=num_controls, control_values=control_values + ) + sub_unitary = cirq.unitary(sub_gate, default=None) + assert sub_unitary is not None + assert sub_unitary.shape == (1, 1) + assert np.isclose(np.abs(sub_unitary[0, 0]), 1.0) + assert not cirq.is_parameterized(sub_gate) + + dim = 2**num_controls + diag_angles = np.zeros(dim) + control_index = sum(val * (2**i) for i, val in enumerate(reversed(control_values))) + diag_angles[control_index] = np.angle(coefficient) + expected_gate = cirq.DiagonalGate(diag_angles_radians=diag_angles) + qids = cirq.LineQubit.range(num_controls) + actual_unitary = cirq.unitary(controlled_gate) + expected_unitary = cirq.unitary(expected_gate.on(*qids)) + cirq.testing.assert_allclose_up_to_global_phase(actual_unitary, expected_unitary, atol=1e-8) From bf8cff7961ed437aae9a2626f4446ab3753bd863 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Fri, 18 Apr 2025 13:49:51 -0700 Subject: [PATCH 2/6] Refactored branch + removed unit test --- cirq-core/cirq/ops/controlled_gate.py | 47 +++++++++------------- cirq-core/cirq/ops/controlled_gate_test.py | 35 +--------------- 2 files changed, 21 insertions(+), 61 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 3d186045121..ac651c82fd2 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -163,35 +163,28 @@ def _decompose_with_context_( control_qubits = list(qubits[: self.num_controls()]) if ( protocols.has_unitary(self.sub_gate) - and protocols.num_qubits(self.sub_gate) == 1 and self._qid_shape_() == (2,) * len(self._qid_shape_()) - and isinstance(self.control_values, cv.ProductOfSums) ): - invert_ops: List[cirq.Operation] = [] - for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): - if set(cvals) == {0}: - invert_ops.append(common_gates.X(cqbit)) - elif set(cvals) == {0, 1}: - control_qubits.remove(cqbit) - decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( - protocols.unitary(self.sub_gate), control_qubits, qubits[-1] - ) - return invert_ops + decomposed_ops + invert_ops - # Handle global phase gates by checking if the unitary is a 1x1 matrix with a global phase - unitary = protocols.unitary(self.sub_gate, default=None) - if ( - unitary is not None - and unitary.shape == (1, 1) - and np.isclose(np.abs(unitary[0, 0]), 1) - ): - shape = self.control_qid_shape - if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}: - return NotImplemented - angle = np.angle(unitary[0, 0]) - rads = np.zeros(shape=shape) - for hot in self.control_values.expand(): - rads[hot] = angle - return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + n_qubits = protocols.num_qubits(self.sub_gate) + # Case 1: Multi-controlled single-qubit gate decomposition + if(n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums)): + invert_ops: List[cirq.Operation] = [] + for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): + if set(cvals) == {0}: + invert_ops.append(common_gates.X(cqbit)) + elif set(cvals) == {0, 1}: + control_qubits.remove(cqbit) + decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( + protocols.unitary(self.sub_gate), control_qubits, qubits[-1] + ) + return invert_ops + decomposed_ops + invert_ops + # Case 2: Global Phase (1x1 Matrix) + if(n_qubits == 0): + angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) + rads = np.zeros(shape=self.control_qid_shape) + for hot in self.control_values.expand(): + rads[hot] = angle + return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) if isinstance(self.sub_gate, common_gates.CZPowGate): z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) num_controls = self.num_controls() + 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 83734b616bc..fb0a1f2d4ec 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -736,37 +736,4 @@ def test_str(): def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) - assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) - - -@pytest.mark.parametrize( - "phase,num_controls,control_values", - [ - (0, 1, (1,)), - (np.pi / 4, 1, (1,)), - (np.pi / 2, 2, (1, 0)), - (np.pi, 2, (0, 1)), - (2 * np.pi, 3, (1, 1, 1)), - ] -) -def test_controlled_global_phase_gate_unitary(phase, num_controls, control_values): - coefficient = np.exp(1j * phase) - sub_gate = cirq.GlobalPhaseGate(coefficient=coefficient) - controlled_gate = cirq.ControlledGate( - sub_gate=sub_gate, num_controls=num_controls, control_values=control_values - ) - sub_unitary = cirq.unitary(sub_gate, default=None) - assert sub_unitary is not None - assert sub_unitary.shape == (1, 1) - assert np.isclose(np.abs(sub_unitary[0, 0]), 1.0) - assert not cirq.is_parameterized(sub_gate) - - dim = 2**num_controls - diag_angles = np.zeros(dim) - control_index = sum(val * (2**i) for i, val in enumerate(reversed(control_values))) - diag_angles[control_index] = np.angle(coefficient) - expected_gate = cirq.DiagonalGate(diag_angles_radians=diag_angles) - qids = cirq.LineQubit.range(num_controls) - actual_unitary = cirq.unitary(controlled_gate) - expected_unitary = cirq.unitary(expected_gate.on(*qids)) - cirq.testing.assert_allclose_up_to_global_phase(actual_unitary, expected_unitary, atol=1e-8) + assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) \ No newline at end of file From 77188f53333e11315a93dc6e23a6aa6b4e3ebc19 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Fri, 18 Apr 2025 14:14:45 -0700 Subject: [PATCH 3/6] Linting update + tests --- cirq-core/cirq/ops/controlled_gate.py | 9 ++++----- cirq-core/cirq/ops/controlled_gate_test.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index ac651c82fd2..4239aaa751e 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -161,13 +161,12 @@ def _decompose_with_context_( self, qubits: Tuple[cirq.Qid, ...], context: Optional[cirq.DecompositionContext] = None ) -> Union[None, NotImplementedType, cirq.OP_TREE]: control_qubits = list(qubits[: self.num_controls()]) - if ( - protocols.has_unitary(self.sub_gate) - and self._qid_shape_() == (2,) * len(self._qid_shape_()) + if protocols.has_unitary(self.sub_gate) and self._qid_shape_() == (2,) * len( + self._qid_shape_() ): n_qubits = protocols.num_qubits(self.sub_gate) # Case 1: Multi-controlled single-qubit gate decomposition - if(n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums)): + if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums): invert_ops: List[cirq.Operation] = [] for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): if set(cvals) == {0}: @@ -179,7 +178,7 @@ def _decompose_with_context_( ) return invert_ops + decomposed_ops + invert_ops # Case 2: Global Phase (1x1 Matrix) - if(n_qubits == 0): + if n_qubits == 0: angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) rads = np.zeros(shape=self.control_qid_shape) for hot in self.control_values.expand(): diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index fb0a1f2d4ec..ebff6b9c709 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -736,4 +736,4 @@ def test_str(): def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) - assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) \ No newline at end of file + assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) From a55ebe82ef0ded3d8b974f1d982055d12e7b7798 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Fri, 18 Apr 2025 18:04:43 -0700 Subject: [PATCH 4/6] Refactor + Unit Test --- cirq-core/cirq/ops/controlled_gate.py | 20 +++++------ cirq-core/cirq/ops/controlled_gate_test.py | 41 +++++++++++++++++++++- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 4239aaa751e..c5dc8ed975b 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -161,11 +161,16 @@ def _decompose_with_context_( self, qubits: Tuple[cirq.Qid, ...], context: Optional[cirq.DecompositionContext] = None ) -> Union[None, NotImplementedType, cirq.OP_TREE]: control_qubits = list(qubits[: self.num_controls()]) - if protocols.has_unitary(self.sub_gate) and self._qid_shape_() == (2,) * len( - self._qid_shape_() - ): + if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits): n_qubits = protocols.num_qubits(self.sub_gate) - # Case 1: Multi-controlled single-qubit gate decomposition + # Case 1: Global Phase (1x1 Matrix) + if n_qubits == 0: + angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) + rads = np.zeros(shape=self.control_qid_shape) + for hot in self.control_values.expand(): + rads[hot] = angle + return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + # Case 2: Multi-controlled single-qubit gate decomposition if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums): invert_ops: List[cirq.Operation] = [] for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): @@ -177,13 +182,6 @@ def _decompose_with_context_( protocols.unitary(self.sub_gate), control_qubits, qubits[-1] ) return invert_ops + decomposed_ops + invert_ops - # Case 2: Global Phase (1x1 Matrix) - if n_qubits == 0: - angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) - rads = np.zeros(shape=self.control_qid_shape) - for hot in self.control_values.expand(): - rads[hot] = angle - return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) if isinstance(self.sub_gate, common_gates.CZPowGate): z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) num_controls = self.num_controls() + 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..3199dcfd95b 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -238,7 +238,6 @@ def test_eq(): eq.add_equality_group(CCH) eq.add_equality_group(cirq.ControlledGate(cirq.H)) eq.add_equality_group(cirq.ControlledGate(cirq.X)) - eq.add_equality_group(cirq.X) eq.add_equality_group( cirq.ControlledGate(cirq.H, control_values=[1, (0, 2)], control_qid_shape=[2, 3]), cirq.ControlledGate(cirq.H, control_values=(1, [0, 2]), control_qid_shape=(2, 3)), @@ -737,3 +736,43 @@ def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) + +@pytest.mark.parametrize( + 'num_controls, angle, control_values', + [ + (1, np.pi / 4, ((1,),)), + (3, -np.pi / 2, ((1,), (1,), (1,))), + (2, 0.0, ((1,), (1,))), + (2, np.pi / 5, ((0,), (0,))), + (3, np.pi, ((1,), (0,), (1,))), + (4, -np.pi / 3, ((0,), (1,), (1,), (0,))), + ], +) +def test_controlled_global_phase_matrix_gate_decomposition_consistency( + num_controls, angle, control_values +): + all_qubits = cirq.LineQubit.range(num_controls) + control_values = cirq.ops.control_values.ProductOfSums(control_values) + control_qid_shape = (2,) * num_controls + phase_value = np.exp(1j * angle) + + cg_global = cirq.ControlledGate( + sub_gate=cirq.GlobalPhaseGate(phase_value), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape + ) + + cg_matrix = cirq.ControlledGate( + sub_gate=cirq.MatrixGate(np.array([[phase_value]])), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape + ) + + decomp_global = cirq.decompose_once(cg_global(*all_qubits)) + decomp_matrix = cirq.decompose_once(cg_matrix(*all_qubits)) + + assert decomp_global is not None + assert decomp_matrix is not None + assert decomp_global == decomp_matrix From 2a48fc609212f34f3527d90a566898fe4eca423c Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Fri, 18 Apr 2025 18:08:42 -0700 Subject: [PATCH 5/6] Unit test + refactor --- cirq-core/cirq/ops/controlled_gate_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 3199dcfd95b..cf026960bcf 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -238,6 +238,7 @@ def test_eq(): eq.add_equality_group(CCH) eq.add_equality_group(cirq.ControlledGate(cirq.H)) eq.add_equality_group(cirq.ControlledGate(cirq.X)) + eq.add_equality_group(cirq.X) eq.add_equality_group( cirq.ControlledGate(cirq.H, control_values=[1, (0, 2)], control_qid_shape=[2, 3]), cirq.ControlledGate(cirq.H, control_values=(1, [0, 2]), control_qid_shape=(2, 3)), From bbb073a75599f7614581803d85b7c0615d730984 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Fri, 18 Apr 2025 18:55:14 -0700 Subject: [PATCH 6/6] Fixed linting --- cirq-core/cirq/ops/controlled_gate_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index cf026960bcf..da79b5a74dc 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -738,6 +738,7 @@ def test_controlled_mixture(): assert cirq.has_mixture(c_yes) assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) + @pytest.mark.parametrize( 'num_controls, angle, control_values', [ @@ -761,19 +762,19 @@ def test_controlled_global_phase_matrix_gate_decomposition_consistency( sub_gate=cirq.GlobalPhaseGate(phase_value), num_controls=num_controls, control_values=control_values, - control_qid_shape=control_qid_shape + control_qid_shape=control_qid_shape, ) cg_matrix = cirq.ControlledGate( sub_gate=cirq.MatrixGate(np.array([[phase_value]])), num_controls=num_controls, control_values=control_values, - control_qid_shape=control_qid_shape + control_qid_shape=control_qid_shape, ) decomp_global = cirq.decompose_once(cg_global(*all_qubits)) decomp_matrix = cirq.decompose_once(cg_matrix(*all_qubits)) - + assert decomp_global is not None assert decomp_matrix is not None assert decomp_global == decomp_matrix