From a72cde9bb95e5818e41e50f09ff6c753abdb683a Mon Sep 17 00:00:00 2001 From: obchain Date: Mon, 22 Jun 2026 16:40:29 +0530 Subject: [PATCH 1/2] Fix abs jvp for complex inputs Abs::jvp returned tangents[0] * sign(z) for complex inputs. Since |z| is real-valued, the framework dropped the imaginary part of that complex product, leaving Re(z * t) / |z| instead of the correct directional derivative Re(conj(z) * t) / |z|. Multiply by conj(sign(z)) and take the real part so the tangent is correct (and real) for complex inputs; the real path is unchanged. The vjp no longer delegates to the jvp (the two differ by a conjugate for complex inputs) but keeps its existing, correct cotangent * sign(z) form. --- mlx/primitives.cpp | 15 +++++++++++++-- python/tests/test_autograd.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3e0f2300fc..181b05f448 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -221,7 +221,11 @@ std::vector Abs::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + assert(primals.size() == 1); + assert(argnums.size() == 1); + // For a complex input the gradient of |z| is the cotangent times sign(z); + // its real and imaginary parts are the derivatives w.r.t. Re(z) and Im(z). + return {multiply(cotangents[0], sign(primals[0], stream()), stream())}; } std::vector Abs::jvp( @@ -230,7 +234,14 @@ std::vector Abs::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return {multiply(tangents[0], sign(primals[0], stream()), stream())}; + auto s = sign(primals[0], stream()); + if (issubdtype(primals[0].dtype(), complexfloating)) { + // |z| is real-valued, so its directional derivative is real: + // d|z| = Re(conj(z) * t) / |z| = Re(conj(sign(z)) * t). 
+ return {real( + multiply(tangents[0], conjugate(s, stream()), stream()), stream())}; + } + return {multiply(tangents[0], s, stream())}; } std::pair, std::vector> Abs::vmap( diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 1ed1bc6997..c97faacd49 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -1168,6 +1168,34 @@ def test_complex_log_vjp(self): # Check against hand-computed vjps self.assertTrue(mx.allclose(vjps[0], expected)) + def test_complex_abs_grad(self): + mx.random.seed(0) + primal = mx.random.normal((3, 4, 5), dtype=mx.complex64) + # guard against values too close to the origin where |z| is not smooth + primal = mx.where(abs(primal) < 1e-3, 1e-3 + 0j, primal) + + # |z| is real-valued, so its jvp is real: + # d|z| = Re(conj(z) * t) / |z| + tangent = mx.random.normal(primal.shape, dtype=mx.complex64) + _, (jvp,) = mx.jvp(mx.abs, [primal], [tangent]) + expected = mx.real(mx.conj(primal) * tangent) / mx.abs(primal) + self.assertEqual(jvp.dtype, mx.float32) + self.assertTrue(mx.allclose(jvp, expected, atol=1e-5)) + + # The vjp's real and imaginary parts are the gradients w.r.t. Re(z) and + # Im(z); for a real cotangent this is cotangent * sign(z). + cotangent = mx.random.normal(primal.shape) + _, (vjp,) = mx.vjp(mx.abs, [primal], [cotangent]) + self.assertTrue( + mx.allclose(vjp, cotangent * (primal / mx.abs(primal)), atol=1e-5) + ) + + # Real inputs are unaffected. 
+ x = mx.random.normal((10,)) + t = mx.random.normal((10,)) + _, (jvp,) = mx.jvp(mx.abs, [x], [t]) + self.assertTrue(mx.allclose(jvp, mx.sign(x) * t)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From f2791b021a089a7aa5ab18007157e9d4d991b6c0 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 23 Jun 2026 14:23:42 -0700 Subject: [PATCH 2/2] Remove comments from gradient calculation in Abs::vjp --- mlx/primitives.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 181b05f448..a5a898d801 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -223,8 +223,6 @@ std::vector Abs::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); - // For a complex input the gradient of |z| is the cotangent times sign(z); - // its real and imaginary parts are the derivatives w.r.t. Re(z) and Im(z). return {multiply(cotangents[0], sign(primals[0], stream()), stream())}; }