diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index d194e645816..c5dc8ed975b 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -161,34 +161,27 @@ def _decompose_with_context_( self, qubits: Tuple[cirq.Qid, ...], context: Optional[cirq.DecompositionContext] = None ) -> Union[None, NotImplementedType, cirq.OP_TREE]: control_qubits = list(qubits[: self.num_controls()]) - if ( - protocols.has_unitary(self.sub_gate) - and protocols.num_qubits(self.sub_gate) == 1 - and self._qid_shape_() == (2,) * len(self._qid_shape_()) - and isinstance(self.control_values, cv.ProductOfSums) - ): - invert_ops: List[cirq.Operation] = [] - for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): - if set(cvals) == {0}: - invert_ops.append(common_gates.X(cqbit)) - elif set(cvals) == {0, 1}: - control_qubits.remove(cqbit) - decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( - protocols.unitary(self.sub_gate), control_qubits, qubits[-1] - ) - return invert_ops + decomposed_ops + invert_ops - if isinstance(self.sub_gate, gp.GlobalPhaseGate): - # A controlled global phase is a diagonal gate, where each active control value index - # is set equal to the phase angle. - shape = self.control_qid_shape - if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}: - # Could work in theory, but DiagonalGate decompose does not support them. - return NotImplemented - angle = np.angle(complex(self.sub_gate.coefficient)) - rads = np.zeros(shape=shape) - for hot in self.control_values.expand(): - rads[hot] = angle - return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits): + n_qubits = protocols.num_qubits(self.sub_gate) + # Case 1: Global Phase (1x1 Matrix) + if n_qubits == 0: + angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) + rads = np.zeros(shape=self.control_qid_shape) + for hot in self.control_values.expand(): + rads[hot] = angle + return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + # Case 2: Multi-controlled single-qubit gate decomposition + if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums): + invert_ops: List[cirq.Operation] = [] + for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): + if set(cvals) == {0}: + invert_ops.append(common_gates.X(cqbit)) + elif set(cvals) == {0, 1}: + control_qubits.remove(cqbit) + decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( + protocols.unitary(self.sub_gate), control_qubits, qubits[-1] + ) + return invert_ops + decomposed_ops + invert_ops if isinstance(self.sub_gate, common_gates.CZPowGate): z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) num_controls = self.num_controls() + 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..da79b5a74dc 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -737,3 +737,44 @@ def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) + + +@pytest.mark.parametrize( + 'num_controls, angle, control_values', + [ + (1, np.pi / 4, ((1,),)), + (3, -np.pi / 2, ((1,), (1,), (1,))), + (2, 0.0, ((1,), (1,))), + (2, np.pi / 5, ((0,), (0,))), + (3, np.pi, ((1,), (0,), (1,))), + (4, -np.pi / 3, ((0,), (1,), (1,), (0,))), + ], +) +def test_controlled_global_phase_matrix_gate_decomposition_consistency( + num_controls, angle, control_values +): + all_qubits = cirq.LineQubit.range(num_controls) + control_values = cirq.ops.control_values.ProductOfSums(control_values) + control_qid_shape = (2,) * num_controls + phase_value = np.exp(1j * angle) + + cg_global = cirq.ControlledGate( + sub_gate=cirq.GlobalPhaseGate(phase_value), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + + cg_matrix = cirq.ControlledGate( + sub_gate=cirq.MatrixGate(np.array([[phase_value]])), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + + decomp_global = cirq.decompose_once(cg_global(*all_qubits)) + decomp_matrix = cirq.decompose_once(cg_matrix(*all_qubits)) + + assert decomp_global is not None + assert decomp_matrix is not None + assert decomp_global == decomp_matrix