
Conversation

@atveit atveit commented Oct 22, 2025

Add a JAX backend for GPT-OSS. The implementation uses BF16 precision with a non-quantized KV cache for efficient autoregressive generation.

Key features:

  • Clean, simplified implementation without experimental optimizations
  • Support for both SafeTensors and Orbax checkpoint formats, with conversion from SafeTensors to Orbax
  • Fast Orbax loading (~18x speedup: ~5 s vs. ~90 s)
  • Automatic checkpoint format detection
  • Integrated with the existing gpt_oss.generate interface
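
The automatic format detection can be sketched roughly as follows. This is an illustration only: the function name and the Orbax marker files checked here are assumptions, not the loader's actual API.

```python
from pathlib import Path

def detect_checkpoint_format(path: str) -> str:
    """Guess whether `path` holds a SafeTensors or an Orbax checkpoint.

    Hypothetical helper; the real loaders may use different names
    and checks.
    """
    p = Path(path)
    # SafeTensors checkpoints ship one or more *.safetensors files
    if any(p.glob("*.safetensors")):
        return "safetensors"
    # Orbax checkpoints are directories of sharded arrays plus metadata
    # files (the exact marker names are an assumption here)
    if (p / "_CHECKPOINT_METADATA").exists() or any(p.rglob("_METADATA")):
        return "orbax"
    raise ValueError(f"unrecognized checkpoint format at {path}")
```

Detecting by directory contents means callers never pass a format flag; the same path argument works for either layout.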

Structure:

  • gpt_oss/jax/: Core model and inference files
  • gpt_oss/jax/scripts/: Checkpoint conversion utilities
  • Clear file naming: loader_safetensors.py, loader_orbax.py

Usage:

  # Convert checkpoint (optional, for faster loading)
  python -m gpt_oss.jax --input gpt-oss-20b/original/ --output gpt-oss-20b-orbax/

  # Run inference
  python -m gpt_oss.generate --backend jax gpt-oss-20b-orbax/ -p "your prompt"

Files: 136 KB across 14 Python files in gpt_oss/jax/

The commit message follows the convention of:

  • Using feat: prefix for new features
  • Clear, concise summary line
  • Detailed description of what and why
  • Key features listed
  • Usage examples
  • Metrics (file size/count)

@atveit atveit changed the title feat: add JAX/Flax reference implementation for CPU inference feat: add JAX/Flax reference implementation for inference Oct 22, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Amund Tveit added 2 commits October 22, 2025 13:52
….lax.cond for token sampling and jax.lax.dynamic_update_slice for KV cache updates. Add @jax.jit decorators to performance-critical functions (token sampling, cache extension, RoPE, SDPA) while removing 40+ assert statements that prevent JIT compilation.
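
The `jax.lax.cond` sampling change can be sketched as follows. This is a minimal illustration of the technique, not the repository's actual sampling function; the signature is an assumption.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sample_token(key, logits, temperature):
    # jax.lax.cond selects a branch at run time while keeping the whole
    # function traceable, so greedy decoding and temperature sampling
    # can coexist inside one jit-compiled function
    return jax.lax.cond(
        temperature == 0.0,
        lambda: jnp.argmax(logits, axis=-1),                # greedy
        lambda: jax.random.categorical(key, logits / temperature),
    )
```

A Python `if` on a traced `temperature` would fail under `jit`; `lax.cond` is the standard way to express data-dependent control flow.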
… KVCache as PyTree

  - Detect Orbax vs SafeTensors format before attempting to load config.json
  - Use load_config_from_orbax fallback for Orbax checkpoints without config files
  - Register KVCache as JAX PyTree to enable JIT compilation with KV caching
  - Fixes TypeError when using experimental jit_generate_loop=True mode

  Addresses code review feedback on commit c85ba18
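
Registering the cache as a pytree can be sketched like this; the field names are assumptions and the real KVCache carries more structure, but the registration pattern is the standard one.

```python
import jax
import jax.numpy as jnp

class KVCache:
    """Toy stand-in for the backend's KV cache (fields are assumed)."""
    def __init__(self, k, v):
        self.k = k  # e.g. (max_seq_len, head_dim)
        self.v = v

# Registering as a pytree lets jit-compiled functions accept and return
# KVCache instances like any other JAX container
jax.tree_util.register_pytree_node(
    KVCache,
    lambda c: ((c.k, c.v), None),              # flatten: children + aux
    lambda aux, children: KVCache(*children),  # unflatten
)

@jax.jit
def append_kv(cache, new_k, new_v, pos):
    # dynamic_update_slice writes the new row at position `pos` without
    # Python-level indexing, so the update stays jit-compatible
    return KVCache(
        jax.lax.dynamic_update_slice(cache.k, new_k, (pos, 0)),
        jax.lax.dynamic_update_slice(cache.v, new_v, (pos, 0)),
    )
```

Without the registration, passing a KVCache through a jitted function raises the TypeError mentioned above, since JAX cannot flatten an unregistered class.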
@atsentia

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


atveit and others added 2 commits October 23, 2025 10:18
…fig.json instead of hardcoding values to support both gpt-oss-20B and gpt-oss-120B models
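
Reading model hyperparameters from config.json rather than hardcoding them might look like this. The field names follow the Hugging Face convention and are assumptions about the actual config schema.

```python
import json
from dataclasses import dataclass
from pathlib import Path

@dataclass
class ModelConfig:
    # a few representative fields; the real config has more
    num_hidden_layers: int
    num_attention_heads: int
    hidden_size: int

def load_config(checkpoint_dir: str) -> ModelConfig:
    # one code path serves both gpt-oss-20b and gpt-oss-120b: the model
    # sizes come from the checkpoint instead of constants in the code
    raw = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    return ModelConfig(
        num_hidden_layers=raw["num_hidden_layers"],
        num_attention_heads=raw["num_attention_heads"],
        hidden_size=raw["hidden_size"],
    )
```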

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@atveit

atveit commented Oct 23, 2025

@codex review - believe all issues have been resolved now?

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Hooray!

