24 changes: 24 additions & 0 deletions README.md
@@ -29,6 +29,7 @@ Both models were trained using our [harmony response format][harmony] and should
- [Reference PyTorch implementation](#reference-pytorch-implementation)
- [Reference Triton implementation (single GPU)](#reference-triton-implementation-single-gpu)
- [Reference Metal implementation](#reference-metal-implementation)
- [Reference JAX implementation](#reference-jax-implementation)
- [Harmony format & tools](#harmony-format--tools)
- [Clients](#clients)
- [Tools](#tools)
@@ -210,6 +211,7 @@ This repository provides a collection of reference implementations:
- [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4× H100 GPUs due to lack of optimization.
- [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [PyTorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching
- [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware
- [`jax`](#reference-jax-implementation) — a [JAX](https://jax.readthedocs.io/)/Flax implementation for CPU inference on Apple Silicon and x86-64
- **Tools:**
- [`browser`](#browser) — a reference implementation of the browser tool the models got trained on
- [`python`](#python) — a stateless reference implementation of the python tool the model got trained on
@@ -237,6 +239,8 @@ pip install gpt-oss
pip install gpt-oss[torch]
# if you want to try the triton implementation
pip install gpt-oss[triton]
# if you want to try the jax implementation
pip install gpt-oss[jax]
```

If you want to modify the code or try the metal implementation, set the project up locally:
@@ -332,6 +336,26 @@ To test it you can run:
python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/model.bin -p "why did the chicken cross the road?"
```

## Reference JAX implementation

We include a JAX/Flax reference implementation for CPU inference on Apple Silicon and x86-64. To install:

```shell
pip install -e ".[jax]"
```

For faster loading (about an 18x speedup), you can optionally convert the SafeTensors checkpoint to the Orbax format:

```shell
python -m gpt_oss.jax --input gpt-oss-20b/original/ --output gpt-oss-20b-orbax/
```

Then run inference (the checkpoint path may point to either a SafeTensors or an Orbax checkpoint):

```shell
python -m gpt_oss.generate --backend jax gpt-oss-20b-orbax/ -p "why did the chicken cross the road?"
```

## Harmony format & tools

Along with the model, we are also releasing a new chat format library `harmony` to interact with the model. Check [this guide](https://cookbook.openai.com/articles/openai-harmony) for more info about harmony.
7 changes: 5 additions & 2 deletions gpt_oss/generate.py
@@ -23,6 +23,9 @@ def main(args):
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
case "jax":
from gpt_oss.jax.token_generator import TokenGenerator as JAXGenerator
generator = JAXGenerator(args.checkpoint, max_context_length=args.context_length)
case _:
raise ValueError(f"Invalid backend: {args.backend}")

@@ -43,7 +46,7 @@
"checkpoint",
metavar="FILE",
type=str,
help="Path to the SafeTensors checkpoint",
help="Path to the checkpoint (SafeTensors for torch/triton/vllm, SafeTensors or Orbax for jax)",
)
parser.add_argument(
"-p",
@@ -75,7 +78,7 @@ def main(args):
metavar="BACKEND",
type=str,
default="torch",
choices=["triton", "torch", "vllm"],
choices=["triton", "torch", "vllm", "jax"],
help="Inference backend",
)
parser.add_argument(
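
The hunk above wires the new backend into the CLI dispatch. The generator can also be constructed directly from Python; a minimal sketch that uses only what this hunk shows (the checkpoint path and `max_context_length` constructor arguments), with placeholder example values:

```python
from gpt_oss.jax.token_generator import TokenGenerator

# Arguments mirror the "jax" case added to gpt_oss/generate.py: a checkpoint
# directory (SafeTensors or Orbax) and a maximum context length.
generator = TokenGenerator("gpt-oss-20b-orbax/", max_context_length=4096)

# Token-by-token generation is then driven by the shared loop in
# gpt_oss/generate.py, the same way as for the torch, triton, and vllm backends.
```
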
20 changes: 20 additions & 0 deletions gpt_oss/jax/__init__.py
@@ -0,0 +1,20 @@
"""JAX/Flax implementation for gpt-oss inference.

This package provides a JAX-based inference implementation for gpt-oss models,
optimized for CPU execution on Apple Silicon (ARM64) and x86-64 platforms.

Key features:
- BF16 precision throughout
- Non-quantized KV caching for efficient autoregressive generation
- Supports both SafeTensors and Orbax checkpoint formats
- MXFP4 weight decompression for MoE expert weights
"""

__all__ = [
'ModelConfig',
'Transformer',
'generate',
'get_tokenizer',
'WeightLoader',
'OrbaxWeightLoader',
]
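
The package docstring above lists MXFP4 weight decompression for the MoE expert weights as a key feature. As a rough illustration of how block-wise MXFP4 dequantization works in general (this is not the PR's actual loader; `dequantize_mxfp4` is a hypothetical helper, and the nibble order and tensor layout are assumptions), each 32-element block stores 4-bit E2M1 codes plus one shared power-of-two scale, so decoding is a table lookup followed by a per-block multiply:

```python
import numpy as np

# FP4 (E2M1) code points: 1 sign bit, 2 exponent bits, 1 mantissa bit.
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)

def dequantize_mxfp4(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Decode MXFP4 blocks (illustrative sketch, not the PR's implementation).

    packed: uint8 array of shape (..., num_blocks * 16), two 4-bit codes per
            byte (low nibble assumed to come first).
    scales: uint8 array of shape (..., num_blocks), E8M0 exponents with bias
            127, one shared scale per 32-element block.
    """
    lo = packed & 0x0F
    hi = packed >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
    values = FP4_VALUES[codes]                                # (..., num_blocks * 32)
    block_scale = np.exp2(scales.astype(np.float32) - 127.0)  # (..., num_blocks)
    values = values.reshape(*scales.shape, 32) * block_scale[..., None]
    return values.reshape(*scales.shape[:-1], -1)
```
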
6 changes: 6 additions & 0 deletions gpt_oss/jax/__main__.py
@@ -0,0 +1,6 @@
"""Entry point for running gpt_oss.jax conversion script as a module."""

from .scripts.convert_checkpoint import main

if __name__ == "__main__":
main()
109 changes: 109 additions & 0 deletions gpt_oss/jax/config.py
@@ -0,0 +1,109 @@
"""Model configuration for gpt-oss-20b.

This configuration is identical to the PyTorch reference implementation,
ensuring compatibility when loading weights and comparing outputs.
"""

from dataclasses import dataclass


@dataclass
class ModelConfig:
"""Configuration for the gpt-oss-20b model architecture.

Attributes:
num_hidden_layers: Number of transformer layers
num_experts: Total number of experts in MoE layers
experts_per_token: Number of experts activated per token
vocab_size: Size of the vocabulary
hidden_size: Dimension of hidden states
intermediate_size: Dimension of MLP intermediate layer
swiglu_limit: Clipping limit for SwiGLU activation
head_dim: Dimension of each attention head
num_attention_heads: Number of attention heads (query)
num_key_value_heads: Number of key/value heads (GQA)
sliding_window: Sliding window size for local attention
initial_context_length: Initial context length for RoPE
rope_theta: Base frequency for RoPE
rope_scaling_factor: Scaling factor for extended context (YaRN)
rope_ntk_alpha: NTK alpha parameter for frequency interpolation
rope_ntk_beta: NTK beta parameter for frequency extrapolation
"""
num_hidden_layers: int = 36
num_experts: int = 128
experts_per_token: int = 4
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
swiglu_limit: float = 7.0
head_dim: int = 64
num_attention_heads: int = 64
num_key_value_heads: int = 8
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0

def __post_init__(self):
"""Validate configuration parameters."""
# Positive value checks
assert self.num_hidden_layers > 0, \
f"num_hidden_layers must be positive, got {self.num_hidden_layers}"
assert self.num_experts > 0, \
f"num_experts must be positive, got {self.num_experts}"
assert self.experts_per_token > 0, \
f"experts_per_token must be positive, got {self.experts_per_token}"
assert self.vocab_size > 0, \
f"vocab_size must be positive, got {self.vocab_size}"
assert self.hidden_size > 0, \
f"hidden_size must be positive, got {self.hidden_size}"
assert self.intermediate_size > 0, \
f"intermediate_size must be positive, got {self.intermediate_size}"
assert self.head_dim > 0, \
f"head_dim must be positive, got {self.head_dim}"
assert self.num_attention_heads > 0, \
f"num_attention_heads must be positive, got {self.num_attention_heads}"
assert self.num_key_value_heads > 0, \
f"num_key_value_heads must be positive, got {self.num_key_value_heads}"

# Logical constraints
assert self.experts_per_token <= self.num_experts, \
f"experts_per_token ({self.experts_per_token}) cannot exceed num_experts ({self.num_experts})"
assert self.num_attention_heads % self.num_key_value_heads == 0, \
f"num_attention_heads ({self.num_attention_heads}) must be divisible by " \
f"num_key_value_heads ({self.num_key_value_heads})"
assert self.intermediate_size % 2 == 0, \
f"intermediate_size must be even for SwiGLU, got {self.intermediate_size}"

# Sliding window check
assert self.sliding_window >= 0, \
f"sliding_window must be non-negative, got {self.sliding_window}"

# RoPE parameter checks
assert self.rope_theta > 0, \
f"rope_theta must be positive, got {self.rope_theta}"
assert self.rope_scaling_factor >= 1.0, \
f"rope_scaling_factor must be >= 1.0, got {self.rope_scaling_factor}"
assert self.rope_ntk_alpha > 0, \
f"rope_ntk_alpha must be positive, got {self.rope_ntk_alpha}"
assert self.rope_ntk_beta > 0, \
f"rope_ntk_beta must be positive, got {self.rope_ntk_beta}"
assert self.initial_context_length > 0, \
f"initial_context_length must be positive, got {self.initial_context_length}"

@property
def q_mult(self) -> int:
"""Number of query heads per key/value head (GQA multiplier)."""
return self.num_attention_heads // self.num_key_value_heads

@property
def total_attention_dim(self) -> int:
"""Total dimension of all attention heads."""
return self.num_attention_heads * self.head_dim

@property
def qkv_dim(self) -> int:
"""Total dimension of concatenated Q, K, V projections."""
return self.head_dim * (self.num_attention_heads + 2 * self.num_key_value_heads)
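
With the defaults above, the derived properties resolve to small, easy-to-check numbers. A quick sanity check using the `ModelConfig` defined in this file:

```python
from gpt_oss.jax.config import ModelConfig

config = ModelConfig()
print(config.q_mult)               # 64 // 8           -> 8 query heads per KV head
print(config.total_attention_dim)  # 64 * 64           -> 4096
print(config.qkv_dim)              # 64 * (64 + 2 * 8) -> 5120
```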