feat: add JAX/Flax reference implementation for inference #217
base: main
Add JAX backend for CPU-based inference on Apple Silicon and x86-64 platforms.
This implementation uses BF16 precision with non-quantized KV caching for efficient
autoregressive generation.
Key features:
- Clean, simplified implementation without experimental optimizations
- Support for both SafeTensors and Orbax checkpoint formats
- Fast Orbax loading (~18x speedup: 5s vs 90s)
- Automatic checkpoint format detection
- Integrated with existing gpt_oss.generate interface
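The automatic format detection listed above could be sketched roughly as follows. This is a minimal illustration and not the PR's actual loader: the function name `detect_checkpoint_format` and the marker-file heuristic for Orbax directories are assumptions made for the example.

```python
from pathlib import Path

# Assumption: an Orbax checkpoint directory contains one of these marker files.
ORBAX_MARKERS = ("_METADATA", "_CHECKPOINT_METADATA")


def detect_checkpoint_format(path: str) -> str:
    """Guess the checkpoint format from directory contents (hypothetical helper)."""
    root = Path(path)
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {root}")
    # SafeTensors checkpoints ship one or more *.safetensors shards.
    if any(root.glob("*.safetensors")):
        return "safetensors"
    # Otherwise look for an Orbax marker file (heuristic, see assumption above).
    if any((root / name).exists() for name in ORBAX_MARKERS):
        return "orbax"
    raise ValueError(f"unrecognized checkpoint layout in {root}")
```

A caller would then dispatch to the matching loader (loader_safetensors.py or loader_orbax.py) based on the returned string.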
Structure:
- gpt_oss/jax/: Core model and inference files
- gpt_oss/jax/scripts/: Checkpoint conversion utilities
- Clear file naming: loader_safetensors.py, loader_orbax.py
Usage:
# Convert checkpoint (optional, for faster loading)
python -m gpt_oss.jax --input gpt-oss-20b/original/ --output gpt-oss-20b-orbax/
# Run inference
python -m gpt_oss.generate --backend jax gpt-oss-20b-orbax/ -p "your prompt"
Files: 136 KB across 14 Python files in gpt_oss/jax/
This follows the convention of:
- Using feat: prefix for new features
- Clear, concise summary line
- Detailed description of what and why
- Key features listed
- Usage examples
- Metrics (file size/count)
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
….lax.cond for token sampling and jax.lax.dynamic_update_slice for KV cache updates. Add @jax.jit decorators to performance-critical functions (token sampling, cache extension, RoPE, SDPA) while removing 40+ assert statements that prevent JIT compilation.
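As a rough illustration of the two primitives named in this commit, the sketch below uses jax.lax.cond to choose between greedy and temperature sampling and jax.lax.dynamic_update_slice to write into a preallocated cache, both under @jax.jit. The function names, argument layout, and cache shape (max_len, n_heads, head_dim) are assumptions for the example and are not taken from the PR.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_token(logits, key, temperature):
    """Greedy decoding when temperature == 0, otherwise temperature sampling."""

    def greedy(logits, key, temperature):
        return jnp.argmax(logits, axis=-1).astype(jnp.int32)

    def stochastic(logits, key, temperature):
        scaled = logits / temperature
        return jax.random.categorical(key, scaled, axis=-1).astype(jnp.int32)

    # lax.cond keeps the branch traceable; a Python `if` on a traced value would fail under jit.
    return jax.lax.cond(temperature > 0.0, stochastic, greedy, logits, key, temperature)


@jax.jit
def update_cache(cache, new_kv, position):
    """Write new_kv (shape (1, n_heads, head_dim)) into cache at row `position`."""
    position = jnp.asarray(position, jnp.int32)
    zero = jnp.zeros((), jnp.int32)
    # The start index may be a traced value, so no recompilation is needed per position.
    return jax.lax.dynamic_update_slice(cache, new_kv.astype(cache.dtype), (position, zero, zero))
```

Python-level assert statements on traced values would fail inside these functions, which is consistent with the commit's note about removing asserts that prevent JIT compilation.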
… KVCache as PyTree
- Detect Orbax vs SafeTensors format before attempting to load config.json
- Use load_config_from_orbax fallback for Orbax checkpoints without config files
- Register KVCache as JAX PyTree to enable JIT compilation with KV caching
- Fixes TypeError when using experimental jit_generate_loop=True mode
Addresses code review feedback on commit c85ba18
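For readers unfamiliar with PyTree registration, a minimal sketch of the idea follows. Only the class name KVCache comes from the commit message; the fields `k`, `v`, `offset` and the helper `advance` are hypothetical. Without the registration decorator, passing such an object to a jitted function raises a TypeError because JAX does not know how to flatten it into arrays.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class KVCache:
    k: jax.Array       # cached keys (hypothetical field)
    v: jax.Array       # cached values (hypothetical field)
    offset: jax.Array  # number of filled positions, stored as an array so it can be traced

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data here.
        return (self.k, self.v, self.offset), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def advance(cache: KVCache, n_tokens: int) -> KVCache:
    """Return a cache whose offset is moved forward by n_tokens."""
    return KVCache(cache.k, cache.v, cache.offset + n_tokens)
```

Because the cache is now a PyTree, it can be an argument or return value of jitted functions and can be carried through loops such as jax.lax.fori_loop.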
@codex review
…fig.json instead of hardcoding values to support both gpt-oss-20B and gpt-oss-120B models
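Reading model dimensions from the checkpoint's config.json, rather than hardcoding them for one model size, might look like the sketch below. The `ModelConfig` fields shown are an illustrative subset chosen for the example, and `load_config` is a hypothetical helper; the real configuration likely carries more keys.

```python
import json
from dataclasses import dataclass, fields
from pathlib import Path


@dataclass(frozen=True)
class ModelConfig:
    # Illustrative subset of fields (assumed names, not the PR's full config).
    num_hidden_layers: int
    num_experts: int
    hidden_size: int


def load_config(checkpoint_dir: str) -> ModelConfig:
    """Build a ModelConfig from <checkpoint_dir>/config.json (hypothetical helper)."""
    raw = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    known = {f.name for f in fields(ModelConfig)}
    missing = known - raw.keys()
    if missing:
        raise KeyError(f"config.json is missing keys: {sorted(missing)}")
    # Ignore keys the dataclass does not know about instead of failing on them.
    return ModelConfig(**{name: raw[name] for name in known})
```

With this shape, the same model code can be instantiated for either model size by pointing it at a different checkpoint directory.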
@codex review - believe all issues have been resolved now?
Codex Review: Didn't find any major issues. Hooray!