Skip to content

Decompose controlled 1x1 unitary gates generically, not just GlobalPhaseGate #7283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 27, 2025
Merged
49 changes: 21 additions & 28 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,34 +152,27 @@ def _decompose_with_context_(
# Prefer the subgate controlled version if available
if self != controlled_sub_gate:
return controlled_sub_gate.on(*qubits)
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
and isinstance(self.control_values, cv.ProductOfSums)
):
invert_ops: list[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, gp.GlobalPhaseGate):
# A controlled global phase is a diagonal gate, where each active control value index
# is set equal to the phase angle.
shape = self.control_qid_shape
if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}:
# Could work in theory, but DiagonalGate decompose does not support them.
return NotImplemented
angle = np.angle(complex(self.sub_gate.coefficient))
rads = np.zeros(shape=shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
n_qubits = protocols.num_qubits(self.sub_gate)
# Case 1: Global Phase (1x1 Matrix)
if n_qubits == 0:
angle = np.angle(protocols.unitary(self.sub_gate)[0, 0])
rads = np.zeros(shape=self.control_qid_shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
# Case 2: Multi-controlled single-qubit gate decomposition
if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums):
invert_ops: list[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
num_controls = self.num_controls() + 1
Expand Down
29 changes: 29 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,32 @@ def test_controlled_mixture():
c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1)
assert cirq.has_mixture(c_yes)
assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))])


@pytest.mark.parametrize(
'num_controls, angle, control_values',
[
(1, np.pi / 4, ((1,),)),
(3, -np.pi / 2, ((1,), (1,), (1,))),
(2, 0.0, ((1,), (1,))),
(2, np.pi / 5, ((0,), (0,))),
(3, np.pi, ((1,), (0,), (1,))),
(4, -np.pi / 3, ((0,), (1,), (1,), (0,))),
],
)
def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, control_values):
all_qubits = cirq.LineQubit.range(num_controls)
control_values = cirq.ops.control_values.ProductOfSums(control_values)
control_qid_shape = (2,) * num_controls
phase_value = np.exp(1j * angle)

cg_matrix = cirq.ControlledGate(
sub_gate=cirq.MatrixGate(np.array([[phase_value]])),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)

decomposed = cirq.decompose(cg_matrix(*all_qubits))
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed)
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))