From db05e3f769d91e5f713dc8e22a3f931e2ddecdd2 Mon Sep 17 00:00:00 2001 From: obchain Date: Thu, 25 Jun 2026 12:51:51 +0530 Subject: [PATCH] Fix complex vjps for several unary ops Square, Sin, Sinh, Cosh, Tan, Tanh, and Log1p delegated their vjp to their jvp. For a holomorphic f the jvp is f'(z) * t, but the vjp must be cotangent * conj(f'(z)); delegating dropped the conjugate, so the gradient through these ops was wrong for complex inputs (e.g. the grad of Re(z**2) was 2*z instead of 2*conj(z)). Conjugate the cotangent into the jvp and conjugate the result, which yields cotangent * conj(f'(z)) and is a no-op for real inputs. --- mlx/primitives.cpp | 35 ++++++++++++++++++++++++++++------- python/tests/test_autograd.py | 22 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 75bae1ba87..6aa7ed8f15 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1701,7 +1701,10 @@ std::vector Cosh::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Cosh::jvp( @@ -2775,7 +2778,10 @@ std::vector Log1p::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Log1p::jvp( @@ -4867,7 +4873,10 @@ std::vector Sin::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Sin::jvp( @@ -4892,7 +4901,10 @@ std::vector Sinh::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Sinh::jvp( @@ -5435,7 +5447,10 @@ std::vector Square::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Square::jvp( @@ -5607,7 +5622,10 @@ std::vector Tan::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Tan::jvp( @@ -5633,7 +5651,10 @@ std::vector Tanh::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + // The vjp conjugates the jvp's multiplier (a no-op for real inputs). + return {conjugate( + jvp(primals, {conjugate(cotangents[0], stream())}, argnums)[0], + stream())}; } std::vector Tanh::jvp( diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index af1b7855e6..3366b7ca78 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -1339,6 +1339,28 @@ def test_complex_log_vjp(self): # Check against hand-computed vjps self.assertTrue(mx.allclose(vjps[0], expected)) + def test_complex_unary_vjps(self): + # For a holomorphic f the vjp is cotangent * conj(f'(z)); these ops used + # to delegate to their jvp and drop the conjugate for complex inputs. + mx.random.seed(0) + z = mx.random.normal((3, 4, 5), dtype=mx.complex64) + cotangent = mx.random.normal((3, 4, 5), dtype=mx.complex64) + z = mx.where(abs(z) < 1e-3, 1e-3 + 0j, z) + + ops = { + mx.square: lambda x: 2 * x, + mx.sin: mx.cos, + mx.sinh: mx.cosh, + mx.cosh: mx.sinh, + mx.tan: lambda x: 1 / mx.cos(x) ** 2, + mx.tanh: lambda x: 1 - mx.tanh(x) ** 2, + mx.log1p: lambda x: 1 / (1 + x), + } + for fn, deriv in ops.items(): + _, (vjp,) = mx.vjp(fn, [z], [cotangent]) + expected = cotangent * mx.conj(deriv(z)) + self.assertTrue(mx.allclose(vjp, expected, atol=1e-5), msg=str(fn)) + def test_complex_abs_grad(self): mx.random.seed(0) primal = mx.random.normal((3, 4, 5), dtype=mx.complex64)