Skip to content

[JAX] WAR for CuDNN MXFP8 norm incorrect result #1700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

jberchtold-nvidia
Copy link
Collaborator

Description

Checks CuDNN version and if below the fixed version does unfused norm and MXFP8 quantize to prevent producing incorrect results.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Check CuDNN version and perform unfused layernorm in JAX if it is below the fixed version
  • If below the fixed CuDNN version, lessen the strictness of the tolerances for MXFP8 norm tests as unfused layernorm loses some precision with the intermediate cast into the input dtype if that is less than fp32

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L0

@phu0ngng phu0ngng added the 2.3.0 label Apr 21, 2025
Check CuDNN version and apply unfused norm if
below a version with the fix

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/mxfp8-norm-cudnn-version-check branch from fb6e5cd to 2e9b5bf Compare April 21, 2025 16:17
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L0

Copy link
Collaborator

@phu0ngng phu0ngng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@jberchtold-nvidia jberchtold-nvidia merged commit a1c18bc into NVIDIA:main Apr 21, 2025
21 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 22, 2025
Check CuDNN version and apply unfused norm if
below a version with the fix

Signed-off-by: Jeremy Berchtold <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants