Skip to content

Fix abs jvp for complex inputs#3745

Merged
angeloskath merged 2 commits into
ml-explore:mainfrom
obchain:fix/abs-complex-jvp
Jun 23, 2026
Merged

Fix abs jvp for complex inputs#3745
angeloskath merged 2 commits into
ml-explore:mainfrom
obchain:fix/abs-complex-jvp

Conversation

@obchain

@obchain obchain commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

Proposed changes

Fixes #3744.

Abs::jvp returned tangents[0] * sign(z) for complex inputs. Since |z| is real-valued, the framework dropped the imaginary part of that complex product, leaving Re(z * t) / |z| instead of the correct directional derivative

d|z| = Re(conj(z) * t) / |z| = Re(conj(sign(z)) * t)

so the imaginary contribution had the wrong sign. The fix multiplies by conj(sign(z)) and takes the real part, giving a correct (and real) tangent for complex inputs; the real path is unchanged.

Abs::vjp previously delegated to jvp. For complex inputs the vjp and jvp differ by a conjugate, so it no longer delegates — it keeps its existing, already-correct cotangent * sign(z) form (whose real/imaginary parts are the gradients w.r.t. Re(z) and Im(z)).

The bug also surfaced through ops like mx.abs(mx.fft.rfft(x)), whose jvp is now correct.

Before:

>>> z = mx.array([1+2j]); t = mx.array([0.5-1j])
>>> mx.jvp(mx.abs, [z], [t])[1][0]      # Re(z*t)/|z|, wrong
>>> mx.real(mx.conj(z)*t)/mx.abs(z)     # correct

Added test_complex_abs_grad to test_autograd.py checking the jvp against the hand-computed value, the vjp against cotangent * sign(z), and that real inputs are unaffected.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

obchain and others added 2 commits June 22, 2026 16:40
Abs::jvp returned tangents[0] * sign(z) for complex inputs. Since |z| is
real-valued, the framework dropped the imaginary part of that complex
product, leaving Re(z * t) / |z| instead of the correct directional
derivative Re(conj(z) * t) / |z|. Multiply by conj(sign(z)) and take the
real part so the tangent is correct (and real) for complex inputs; the
real path is unchanged. The vjp no longer delegates to the jvp (the two
differ by a conjugate for complex inputs) but keeps its existing, correct
cotangent * sign(z) form.

@angeloskath angeloskath left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@angeloskath angeloskath merged commit 9d0b04b into ml-explore:main Jun 23, 2026
16 checks passed
tpegolotti pushed a commit that referenced this pull request Jun 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] jvp of mx.abs is wrong for complex inputs (missing conjugate)

2 participants