
Fix complex vjps for several unary ops #3766

Open
obchain wants to merge 1 commit into ml-explore:main from obchain:fix/complex-unary-vjp

@obchain obchain commented Jun 25, 2026

Contributor

Proposed changes

Fixes #3765.

Square, Sin, Sinh, Cosh, Tan, Tanh, and Log1p delegated their vjp to their jvp. For a holomorphic f the jvp is f'(z) * t, but the vjp must be cotangent * conj(f'(z)) (the convention already used by Exp/Log). Delegating dropped the conjugate, so gradients through these ops were wrong for complex inputs — e.g. the gradient of Re(z**2) was 2*z instead of 2*conj(z).

Each vjp now conjugates the cotangent, runs the jvp, and conjugates the result — i.e. conj(jvp(conj(w))) = w * conj(f'(z)) — which is a no-op for real inputs (so real gradients and the jvps are unchanged).
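The identity above can be checked outside the library with plain Python. This is a minimal sketch, not the MLX implementation: `DERIVS`, `jvp`, `vjp_old`, and `vjp_fixed` are hypothetical stand-ins that model each op's jvp as `f'(z) * t` using `cmath`.

```python
import cmath

# Hypothetical stand-ins: closed-form derivatives f'(z) of the seven ops.
DERIVS = {
    "square": lambda z: 2 * z,
    "sin": cmath.cos,
    "sinh": cmath.cosh,
    "cosh": cmath.sinh,
    "tan": lambda z: 1 / cmath.cos(z) ** 2,
    "tanh": lambda z: 1 / cmath.cosh(z) ** 2,
    "log1p": lambda z: 1 / (1 + z),
}


def jvp(name, z, t):
    # For a holomorphic f the jvp is f'(z) * t.
    return DERIVS[name](z) * t


def vjp_old(name, z, w):
    # Old behaviour: delegate straight to the jvp, which drops the conjugate.
    return jvp(name, z, w)


def vjp_fixed(name, z, w):
    # New behaviour: conj(jvp(conj(w))) == w * conj(f'(z)).
    return jvp(name, z, w.conjugate()).conjugate()


z, w = 0.3 + 0.7j, 1.5 - 0.4j
for name, deriv in DERIVS.items():
    expected = w * deriv(z).conjugate()
    assert abs(vjp_fixed(name, z, w) - expected) < 1e-12, name
    assert abs(vjp_old(name, z, w) - expected) > 1e-6, name  # old result differs

# With real inputs both conjugations are no-ops, so old and new agree.
x, g = complex(0.3), complex(1.5)
for name in DERIVS:
    assert abs(vjp_fixed(name, x, g) - vjp_old(name, x, g)) < 1e-12, name
```

For `square` at `z = 1 + 2j` with cotangent `1`, `vjp_fixed` gives `2 - 4j`, i.e. `2*conj(z)`, while `vjp_old` gives `2 + 4j`.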

Verified each op's complex gradient against finite differences and against the hand-computed cotangent * conj(f'(z)); real gradients and all jvps are unchanged. Added test_complex_unary_vjps.
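A finite-difference check of this kind might look like the sketch below. It is not the code of `test_complex_unary_vjps`. It assumes the gradient of a real loss `L` with respect to `z = x + iy` is `dL/dx + 1j * dL/dy`, a convention that reproduces the `2*conj(z)` result for `Re(z**2)`. The helper `fd_grad` and the table `OPS` are hypothetical names, and the ops are modeled with `cmath` rather than the library.

```python
import cmath


def fd_grad(f, z, w, eps=1e-6):
    """Central-difference estimate of dL/dx + 1j*dL/dy for L(z) = Re(conj(w) * f(z))."""
    def loss(u):
        return (w.conjugate() * f(u)).real

    dx = (loss(z + eps) - loss(z - eps)) / (2 * eps)
    dy = (loss(z + 1j * eps) - loss(z - 1j * eps)) / (2 * eps)
    return complex(dx, dy)


# (f, f') pairs standing in for the ops touched by this PR.
OPS = {
    "square": (lambda z: z * z, lambda z: 2 * z),
    "sin": (cmath.sin, cmath.cos),
    "sinh": (cmath.sinh, cmath.cosh),
    "cosh": (cmath.cosh, cmath.sinh),
    "tan": (cmath.tan, lambda z: 1 / cmath.cos(z) ** 2),
    "tanh": (cmath.tanh, lambda z: 1 / cmath.cosh(z) ** 2),
    "log1p": (lambda z: cmath.log(1 + z), lambda z: 1 / (1 + z)),
}

z, w = 0.3 + 0.7j, 1.5 - 0.4j
for name, (f, df) in OPS.items():
    analytic = w * df(z).conjugate()  # hand-computed cotangent * conj(f'(z))
    numeric = fd_grad(f, z, w)
    assert abs(numeric - analytic) < 1e-6, name
```

The loss `Re(conj(w) * f(z))` is chosen so that `w` plays the role of the cotangent; with `w = 1` and `f(z) = z**2` it reduces to the `Re(z**2)` example from the description.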

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Square, Sin, Sinh, Cosh, Tan, Tanh, and Log1p delegated their vjp to
their jvp. For a holomorphic f the jvp is f'(z) * t, but the vjp must be
cotangent * conj(f'(z)); delegating dropped the conjugate, so the
gradient through these ops was wrong for complex inputs (e.g. the grad of
Re(z**2) was 2*z instead of 2*conj(z)). Conjugate the cotangent into the
jvp and conjugate the result, which yields cotangent * conj(f'(z)) and is
a no-op for real inputs.

Development

Successfully merging this pull request may close these issues.

[BUG] Complex vjps are wrong for square/sin/sinh/cosh/tan/tanh/log1p (missing conjugate)