Skip to content

Improve controlled gate decomposition logic and add targeted unit tests #7283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,34 +161,27 @@ def _decompose_with_context_(
self, qubits: Tuple[cirq.Qid, ...], context: Optional[cirq.DecompositionContext] = None
) -> Union[None, NotImplementedType, cirq.OP_TREE]:
control_qubits = list(qubits[: self.num_controls()])
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
and isinstance(self.control_values, cv.ProductOfSums)
):
invert_ops: List[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, gp.GlobalPhaseGate):
# A controlled global phase is a diagonal gate, where each active control value index
# is set equal to the phase angle.
shape = self.control_qid_shape
if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}:
# Could work in theory, but DiagonalGate decompose does not support them.
return NotImplemented
angle = np.angle(complex(self.sub_gate.coefficient))
rads = np.zeros(shape=shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
n_qubits = protocols.num_qubits(self.sub_gate)
# Case 1: Global Phase (1x1 Matrix)
if n_qubits == 0:
angle = np.angle(protocols.unitary(self.sub_gate)[0, 0])
rads = np.zeros(shape=self.control_qid_shape)
for hot in self.control_values.expand():
rads[hot] = angle
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
# Case 2: Multi-controlled single-qubit gate decomposition
if n_qubits == 1 and isinstance(self.control_values, cv.ProductOfSums):
invert_ops: List[cirq.Operation] = []
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
control_qubits.remove(cqbit)
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
num_controls = self.num_controls() + 1
Expand Down
41 changes: 41 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,44 @@ def test_controlled_mixture():
c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1)
assert cirq.has_mixture(c_yes)
assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))])


@pytest.mark.parametrize(
'num_controls, angle, control_values',
[
(1, np.pi / 4, ((1,),)),
(3, -np.pi / 2, ((1,), (1,), (1,))),
(2, 0.0, ((1,), (1,))),
(2, np.pi / 5, ((0,), (0,))),
(3, np.pi, ((1,), (0,), (1,))),
(4, -np.pi / 3, ((0,), (1,), (1,), (0,))),
],
)
def test_controlled_global_phase_matrix_gate_decomposition_consistency(
num_controls, angle, control_values
):
all_qubits = cirq.LineQubit.range(num_controls)
control_values = cirq.ops.control_values.ProductOfSums(control_values)
control_qid_shape = (2,) * num_controls
phase_value = np.exp(1j * angle)

cg_global = cirq.ControlledGate(
sub_gate=cirq.GlobalPhaseGate(phase_value),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)

cg_matrix = cirq.ControlledGate(
sub_gate=cirq.MatrixGate(np.array([[phase_value]])),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)

decomp_global = cirq.decompose_once(cg_global(*all_qubits))
decomp_matrix = cirq.decompose_once(cg_matrix(*all_qubits))

assert decomp_global is not None
assert decomp_matrix is not None
assert decomp_global == decomp_matrix