1 change: 1 addition & 0 deletions RELEASES.md
@@ -5,6 +5,7 @@
#### Closed issues
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
- Add test for build from source (PR #772, Issue #764)
- Stable `ot.TorchBackend.sqrtm` around repeated eigvals (PR #774, Issue #773)

## 0.9.6.post1

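For context on the `sqrtm` entry above: the previous Torch implementation differentiated straight through `torch.linalg.eigh` (see the removed lines in `ot/backend.py` below), and the `eigh` backward divides by eigenvalue gaps, so inputs with repeated eigenvalues typically yield non-finite gradients. A minimal repro sketch of that failure mode (illustrative only, not part of the PR; exact output may vary by PyTorch version):

```python
# Sketch of the pre-PR failure mode: backprop through an eigh-based sqrtm
# on a matrix with a repeated eigenvalue. The eigh backward divides by
# eigenvalue differences, which are zero here.
import torch

A = torch.eye(3, dtype=torch.float64, requires_grad=True)  # eigenvalue 1 repeated
L, V = torch.linalg.eigh(A)
Y = (V * L.sqrt()) @ V.transpose(-2, -1)   # equivalent to the old sqrtm
Y.sum().backward()
print(torch.isfinite(A.grad).all())        # typically tensor(False)
```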
32 changes: 26 additions & 6 deletions ot/backend.py
@@ -1938,6 +1938,7 @@ def __init__(self):
self.rng_cuda_ = torch.Generator("cpu")

from torch.autograd import Function
from torch.autograd.function import once_differentiable

# define a function that takes inputs val and grads
# and returns a val tensor with proper gradients
@@ -1952,7 +1953,31 @@ def backward(ctx, grad_output):
# the gradients are grad
return (None, None) + tuple(g * grad_output for g in ctx.grads)

# define a differentiable SPD matrix sqrt
# with closed-form VJP
class MatrixSqrtFunction(Function):
@staticmethod
def forward(ctx, a):
a_sym = 0.5 * (a + a.transpose(-2, -1))
L, V = torch.linalg.eigh(a_sym)
s = L.clamp_min(0).sqrt()
y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
ctx.save_for_backward(s, V)
return y

@staticmethod
@once_differentiable
def backward(ctx, g):
s, V = ctx.saved_tensors
g_sym = 0.5 * (g + g.transpose(-2, -1))
ghat = V.transpose(-2, -1) @ g_sym @ V
d = s.unsqueeze(-1) + s.unsqueeze(-2)
xhat = ghat / d
xhat = xhat.masked_fill(d == 0, 0)
return V @ xhat @ V.transpose(-2, -1)

self.ValFunction = ValFunction
self.MatrixSqrtFunction = MatrixSqrtFunction

def _to_numpy(self, a):
if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
return torch.linalg.pinv(a, hermitian=hermitian)

def sqrtm(self, a):
L, V = torch.linalg.eigh(a)
L = torch.sqrt(L)
# Q[...] = V[...] @ diag(L[...])
Q = torch.einsum("...jk,...k->...jk", V, L)
# R[...] = Q[...] @ V[...].T
return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
return self.MatrixSqrtFunction.apply(a)

def eigh(self, a):
return torch.linalg.eigh(a)
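For reference, the closed-form VJP in `MatrixSqrtFunction.backward` is the solution `X` of the Sylvester equation `Y @ X + X @ Y = sym(G)`, where `Y = sqrtm(A)` and `G` is the incoming cotangent; the division by `s_i + s_j` in the eigenbasis is well defined for SPD inputs even when eigenvalues repeat. A small sketch (illustrative, not part of the PR) that checks this identity numerically:

```python
# Illustrative check: the VJP computed by the new backward solves
#   Y @ X + X @ Y = sym(G),  with Y = sqrtm(A).
import torch

torch.manual_seed(0)
d = 4
B = torch.randn(d, d, dtype=torch.float64)
A = B @ B.T + d * torch.eye(d, dtype=torch.float64)   # SPD input

L, V = torch.linalg.eigh(A)
s = L.clamp_min(0).sqrt()
Y = (V * s) @ V.T                                      # matrix square root

G = torch.randn(d, d, dtype=torch.float64)             # incoming cotangent
G_sym = 0.5 * (G + G.T)
X = V @ ((V.T @ G_sym @ V) / (s[:, None] + s[None, :])) @ V.T

print(torch.allclose(Y @ X + X @ Y, G_sym))            # expected: True
```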
13 changes: 13 additions & 0 deletions test/test_backend.py
@@ -822,6 +822,19 @@ def fun(a, b, d):
assert nx.allclose(dl_db, b)


def test_sqrtm_backward_torch():
if not torch:
pytest.skip("Torch not available")
nx = ot.backend.TorchBackend()
torch.manual_seed(42)
d = 5
A = torch.randn(d, d, dtype=torch.float64, device="cpu")
A = A @ A.T
A.requires_grad_(True)
func = lambda x: nx.sqrtm(x).sum()
assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)


def test_get_backend_none():
a, b = np.zeros((2, 3)), None
nx = get_backend(a, b)
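The new test uses a generic random SPD matrix; the degenerate case behind Issue #773 can be exercised the same way. A hedged sketch (not part of the PR) that runs `gradcheck` on a matrix with a deliberately repeated eigenvalue:

```python
# Sketch: gradient check of the new sqrtm on an SPD matrix with a repeated
# eigenvalue, mirroring the test added in this PR.
import torch
import ot

nx = ot.backend.TorchBackend()

torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
eigvals = torch.tensor([2.0, 2.0, 2.0, 5.0], dtype=torch.float64)
A = (Q * eigvals) @ Q.T                      # SPD with eigenvalue 2 repeated
A.requires_grad_(True)

assert torch.autograd.gradcheck(
    lambda x: nx.sqrtm(x).sum(), (A,), atol=1e-4, rtol=1e-4
)
```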