diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 915c0805c46..4ed6a03b876 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -152,34 +152,27 @@ def _decompose_with_context_( # Prefer the subgate controlled version if available if self != controlled_sub_gate: return controlled_sub_gate.on(*qubits) - if ( - protocols.has_unitary(self.sub_gate) - and protocols.num_qubits(self.sub_gate) == 1 - and self._qid_shape_() == (2,) * len(self._qid_shape_()) - and isinstance(self.control_values, cv.ProductOfSums) - ): - invert_ops: list[cirq.Operation] = [] - for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): - if set(cvals) == {0}: - invert_ops.append(common_gates.X(cqbit)) - elif set(cvals) == {0, 1}: - control_qubits.remove(cqbit) - decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( - protocols.unitary(self.sub_gate), control_qubits, qubits[-1] - ) - return invert_ops + decomposed_ops + invert_ops - if isinstance(self.sub_gate, gp.GlobalPhaseGate): - # A controlled global phase is a diagonal gate, where each active control value index - # is set equal to the phase angle. - shape = self.control_qid_shape - if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}: - # Could work in theory, but DiagonalGate decompose does not support them. - return NotImplemented - angle = np.angle(complex(self.sub_gate.coefficient)) - rads = np.zeros(shape=shape) - for hot in self.control_values.expand(): - rads[hot] = angle - return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits): + n_qubits = protocols.num_qubits(self.sub_gate) + # Case 1: Global Phase (1x1 Matrix) + if n_qubits == 0: + angle = np.angle(protocols.unitary(self.sub_gate)[0, 0]) + rads = np.zeros(shape=self.control_qid_shape) + for hot in self.control_values.expand(): + rads[hot] = angle + return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits) + # Case 2: Multi-controlled single-qubit gate decomposition + if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums): + invert_ops: list[cirq.Operation] = [] + for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]): + if set(cvals) == {0}: + invert_ops.append(common_gates.X(cqbit)) + elif set(cvals) == {0, 1}: + control_qubits.remove(cqbit) + decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation( + protocols.unitary(self.sub_gate), control_qubits, qubits[-1] + ) + return invert_ops + decomposed_ops + invert_ops if isinstance(self.sub_gate, common_gates.CZPowGate): z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) num_controls = self.num_controls() + 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 121ced1df78..f970f524688 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -773,3 +773,32 @@ def test_controlled_mixture(): c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1) assert cirq.has_mixture(c_yes) assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))]) + + +@pytest.mark.parametrize( + 'num_controls, angle, control_values', + [ + (1, np.pi / 4, ((1,),)), + (3, -np.pi / 2, ((1,), (1,), (1,))), + (2, 0.0, ((1,), (1,))), + (2, np.pi / 5, ((0,), (0,))), + (3, np.pi, ((1,), (0,), (1,))), + (4, -np.pi / 3, ((0,), (1,), (1,), (0,))), + ], +) +def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, control_values): + all_qubits = cirq.LineQubit.range(num_controls) + control_values = cirq.ops.control_values.ProductOfSums(control_values) + control_qid_shape = (2,) * num_controls + phase_value = np.exp(1j * angle) + + cg_matrix = cirq.ControlledGate( + sub_gate=cirq.MatrixGate(np.array([[phase_value]])), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + + decomposed = cirq.decompose(cg_matrix(*all_qubits)) + assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed) + np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))