diff --git a/RELEASES.md b/RELEASES.md
index ae4c3fadc..ff9fe13f5 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,7 @@
 #### Closed issues
 - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
 - Add test for build from source (PR #772, Issue #764)
+- Stabilize `ot.TorchBackend.sqrtm` for matrices with repeated eigenvalues (PR #774, Issue #773)
 
 ## 0.9.6.post1
 
diff --git a/ot/backend.py b/ot/backend.py
index f14da588b..a11c78209 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable
 
         # define a function that takes inputs val and grads
         # ad returns a val tensor with proper gradients
@@ -1952,7 +1953,31 @@ def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
 
+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
+
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction
 
     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)
 
     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)
 
     def eigh(self, a):
         return torch.linalg.eigh(a)
diff --git a/test/test_backend.py b/test/test_backend.py
index 994895fda..2a0fc9a48 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -822,6 +822,19 @@ def fun(a, b, d):
     assert nx.allclose(dl_db, b)
 
 
+def test_sqrtm_backward_torch():
+    if not torch:
+        pytest.skip("Torch not available")
+    nx = ot.backend.TorchBackend()
+    torch.manual_seed(42)
+    d = 5
+    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+    A = A @ A.T
+    A.requires_grad_(True)
+    func = lambda x: nx.sqrtm(x).sum()
+    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+
+
 def test_get_backend_none():
     a, b = np.zeros((2, 3)), None
     nx = get_backend(a, b)
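
A note on the backward pass: `MatrixSqrtFunction.backward` applies the divided-difference form of the square-root VJP in the eigenbasis, dividing `ghat[i, j]` by `s[i] + s[j]`, which only vanishes when both eigenvalues are zero (that edge case is zeroed out via `masked_fill`). The previous implementation differentiated through `torch.linalg.eigh`, whose autograd rule divides by eigenvalue gaps and is not well defined when eigenvalues repeat. The snippet below is a minimal sanity check of that behaviour; it assumes PyTorch and a POT build that includes this patch, and the degenerate test matrix `A` is purely illustrative.

```python
import torch
from ot.backend import TorchBackend

# TorchBackend.sqrtm now routes through the closed-form MatrixSqrtFunction.
nx = TorchBackend()

# diag(4, 4, 1) has a repeated eigenvalue, the degenerate case behind Issue #773
# (this particular matrix is only an illustrative example, not taken from the issue).
A = torch.diag(torch.tensor([4.0, 4.0, 1.0], dtype=torch.float64))
A.requires_grad_(True)

loss = nx.sqrtm(A).sum()
loss.backward()

# With the closed-form VJP the gradient stays finite despite the repeated eigenvalue.
print(torch.isfinite(A.grad).all())  # expected: tensor(True)
```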