Skip to content
Merged
49 changes: 21 additions & 28 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,34 +152,27 @@ def _decompose_with_context_(
# Prefer the subgate controlled version if available
if self != controlled_sub_gate:
return controlled_sub_gate.on(*qubits)
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
and isinstance(self.control_values, cv.ProductOfSums)
):
invert_ops: list[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, gp.GlobalPhaseGate):
# A controlled global phase is a diagonal gate, where each active control value index
# is set equal to the phase angle.
shape = self.control_qid_shape
if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}:
# Could work in theory, but DiagonalGate decompose does not support them.
return NotImplemented
angle = np.angle(complex(self.sub_gate.coefficient))
rads = np.zeros(shape=shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
n_qubits = protocols.num_qubits(self.sub_gate)
# Case 1: Global Phase (1x1 Matrix)
if n_qubits == 0:
angle = np.angle(protocols.unitary(self.sub_gate)[0, 0])
rads = np.zeros(shape=self.control_qid_shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
# Case 2: Multi-controlled single-qubit gate decomposition
if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums):
invert_ops: list[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
num_controls = self.num_controls() + 1
Expand Down
29 changes: 29 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,32 @@ def test_controlled_mixture():
c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1)
assert cirq.has_mixture(c_yes)
assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))])


@pytest.mark.parametrize(
'num_controls, angle, control_values',
[
(1, np.pi / 4, ((1,),)),
(3, -np.pi / 2, ((1,), (1,), (1,))),
(2, 0.0, ((1,), (1,))),
(2, np.pi / 5, ((0,), (0,))),
(3, np.pi, ((1,), (0,), (1,))),
(4, -np.pi / 3, ((0,), (1,), (1,), (0,))),
],
)
def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, control_values):
all_qubits = cirq.LineQubit.range(num_controls)
control_values = cirq.ops.control_values.ProductOfSums(control_values)
control_qid_shape = (2,) * num_controls
phase_value = np.exp(1j * angle)

cg_matrix = cirq.ControlledGate(
sub_gate=cirq.MatrixGate(np.array([[phase_value]])),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)

decomposed = cirq.decompose(cg_matrix(*all_qubits))
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed)
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))