From 5262a35d061be3b753a9ba30bf158269e0e71ea2 Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Mon, 21 Apr 2025 14:17:46 -0700 Subject: [PATCH 1/2] Use SVD-based fidelity for density matrices and add numerical stability test --- cirq-core/cirq/qis/measures.py | 9 ++++----- cirq-core/cirq/qis/measures_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/qis/measures.py b/cirq-core/cirq/qis/measures.py index f348d1e7145..16d514f5c23 100644 --- a/cirq-core/cirq/qis/measures.py +++ b/cirq-core/cirq/qis/measures.py @@ -242,11 +242,10 @@ def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.n # state1 is a density matrix and state2 is a state vector return np.real(np.conjugate(state2) @ state1 @ state2) elif state1.ndim == 2 and state2.ndim == 2: - # Both density matrices - state1_sqrt = _sqrt_positive_semidefinite_matrix(state1) - eigs = linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt) - trace = np.sum(np.sqrt(np.abs(eigs))) - return trace**2 + # Both density matrices: use SVD-based fidelity for numerical stability + rho1_sqrt = linalg.sqrtm(state1) + rho2_sqrt = linalg.sqrtm(state2) + return (np.sum(linalg.svdvals(rho1_sqrt @ rho2_sqrt))) ** 2 # matrix is reshaped before this point raise ValueError( # pragma: no cover 'The given arrays must be one- or two-dimensional. ' diff --git a/cirq-core/cirq/qis/measures_test.py b/cirq-core/cirq/qis/measures_test.py index beb42c29492..3208a282ee2 100644 --- a/cirq-core/cirq/qis/measures_test.py +++ b/cirq-core/cirq/qis/measures_test.py @@ -16,6 +16,8 @@ import pytest import cirq +from cirq import partial_trace +import cirq.qis.measures as measures N = 15 VEC1 = cirq.testing.random_superposition(N) @@ -183,6 +185,30 @@ def test_fidelity_bad_shape(): _ = cirq.fidelity(np.array([[[1.0]]]), np.array([[[1.0]]]), qid_shape=(1,)) +def test_fidelity_numerical_stability_high_dim(): + init_qubits = 10 + final_qubits = init_qubits - 1 + rng = np.random.RandomState(42) + psi = rng.randn(2**init_qubits) + 1j * rng.randn(2**init_qubits) + psi /= np.linalg.norm(psi) + rho = np.outer(psi, np.conjugate(psi)) + rho_reshaped = rho.reshape((2,) * (init_qubits * 2)) + keep_idxs = list(range(final_qubits)) + rho_reduced = partial_trace(rho_reshaped, keep_idxs).reshape((2**final_qubits,) * 2) + + # Direct fidelity computation (old) + rho1_sqrt = measures._sqrt_positive_semidefinite_matrix(rho_reduced) + eigs = measures.linalg.eigvalsh(rho1_sqrt @ rho_reduced @ rho1_sqrt) + cirq_fidelity = (np.sum(np.sqrt(np.abs(eigs)))) ** 2 + # SVD-based fidelity (patched) + get_fidelity = cirq.fidelity( + rho_reduced, rho_reduced, validate=False, qid_shape=(2,) * final_qubits + ) + # Old version should exceed 1, new should be ~1 + assert cirq_fidelity > 1 + 1e-6 + assert get_fidelity == pytest.approx(1, abs=1e-6) + + def test_von_neumann_entropy(): # 1x1 matrix assert cirq.von_neumann_entropy(np.array([[1]])) == 0 From b7918953ad009407a330b3a646558a54cbb096ad Mon Sep 17 00:00:00 2001 From: Revanth Gundala Date: Mon, 21 Apr 2025 14:35:24 -0700 Subject: [PATCH 2/2] Fixed linting --- cirq-core/cirq/qis/measures_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/qis/measures_test.py b/cirq-core/cirq/qis/measures_test.py index 3208a282ee2..232c3fb8410 100644 --- a/cirq-core/cirq/qis/measures_test.py +++ b/cirq-core/cirq/qis/measures_test.py @@ -16,8 +16,8 @@ import pytest import cirq -from cirq import partial_trace import cirq.qis.measures as measures +from cirq import partial_trace N = 15 VEC1 = cirq.testing.random_superposition(N)