GPT models built with JAX
This project implements the GPT series of models using JAX and Flax's NNX library.
Install the uv Python package manager:
curl -LsSf https://astral.sh/uv/install.sh | sh
The main commands available in the Makefile are:
- make install - Install dependencies from lockfile
- make dev - Install all dependencies including dev from lockfile
- make clean - Clean build artifacts and cache
- make build - Build package
- make lint - Run linting
- make format - Format code
- make lab - Run Jupyter lab server from the project directory
To see all available commands and their descriptions, run: make help
The training run can be reproduced using notebooks/train_gpt2.ipynb.
A machine with 8 x Nvidia A100 80GB GPUs was used to train for 1 epoch on a 10bn token sample of the
FineWeb-Edu dataset. Validation was performed on 1% of the dataset.
The trained model was evaluated on the HellaSwag benchmark, where it achieved a score of 0.3025.
This project includes a variety of transformer-based language models with different architectures and optimizations:
- GPT - Standard GPT-2 style transformer with causal self-attention and MLP layers
- VGPT - Variant of GPT with similar architecture but different initialization and configuration
- NoPE_GPT - GPT model without positional embeddings, for experiments on sequence modeling
- GLU_GPT - GPT with Gated Linear Units (GLU) in the MLP layers
- RoPE_GPT - GPT with Rotary Position Embeddings (RoPE) applied in attention
- GLU_RoPE_GPT - Combines GLU MLP layers with RoPE
- GQA_GPT - GPT with Grouped Query Attention for reduced memory usage during inference
- GQA_GLU_RoPE_GPT - Combines GQA, GLU, and RoPE
- Tiny_MoE - Compact Mixture of Experts model with alternating MoE and GLU blocks, includes load balancing losses
- Tiny_MoE_2 - Fully MoE-based architecture with uniform expert layers and auxiliary loss coefficients
- Tiny_MoE_3 - Uses Soft MoE routing mechanism with alternating MoE and GLU block structure
- SmolLM - Lightweight language model inspired by small LLaMA variants with RMSNorm and efficient attention
- Charformer - Character-level transformer model designed for text generation at character granularity
- Bigram - Simple bigram language model for baseline comparisons and educational purposes
GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs (FSDP-like) for embed, pos_embed, ln, mlp, attn layers
- Attention: CausalSelfAttention with XLA/CUDNN SDPA implementation
- Activation: Standard MLP (no explicit activation specified)
- Positional: Learnable positional embeddings
- Normalization: LayerNorm with epsilon 1e-5
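The causal self-attention listed above can be sketched in plain `jax.numpy`. This is a minimal illustration, not the project's `CausalSelfAttention` module: the function name, tensor layout, and sizes are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(q, k, v):
    """Scaled dot-product attention with a causal mask.

    q, k, v have shape (batch, seq_len, n_heads, head_dim).
    """
    seq_len, head_dim = q.shape[1], q.shape[-1]
    # Scores for every (query, key) pair: (batch, heads, seq, seq).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    # Lower-triangular mask: position i may only attend to positions <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


# Illustrative sizes (assumptions): batch 2, 8 positions, 4 heads, head_dim 16.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 4, 16))
k = jax.random.normal(kk, (2, 8, 4, 16))
v = jax.random.normal(kv, (2, 8, 4, 16))
out = causal_self_attention(q, k, v)
```

Recent JAX releases also provide `jax.nn.dot_product_attention`, which accepts `is_causal=True` and an `implementation` argument (`"xla"` or `"cudnn"`); that is likely the kind of fused path the "XLA/CUDNN SDPA" entry refers to.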
VGPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs (identical to GPT)
- Attention: CausalSelfAttention with XLA/CUDNN SDPA implementation
- Activation: Standard MLP
- Positional: Learnable positional embeddings
- Special: Compatible with VGPT2 HuggingFace checkpoints
NoPE_GPT
- Precision: float32
- Sharding: No explicit sharding specifications
- Attention: CausalSelfAttention with XLA implementation
- Activation: Standard MLP with ReLU
- Positional: None (no positional embeddings)
- Purpose: Baseline model for positional encoding ablation studies
GLU_GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs for GLU components (glu_fc, glu_gate, glu_proj)
- Attention: CausalSelfAttention (no RoPE)
- Activation: SiLU (default), supports GELU, Sigmoid
- Positional: Learnable positional embeddings
- MLP: Gated Linear Unit (GLU) architecture
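A gated MLP of this kind can be sketched as follows. The weight names mirror the `glu_fc`, `glu_gate`, and `glu_proj` components named in the sharding entry, but the function itself, the omission of biases, and the sizes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def glu_mlp(x, w_fc, w_gate, w_proj, activation=jax.nn.silu):
    """Gated MLP: proj(activation(gate(x)) * fc(x))."""
    hidden = x @ w_fc                # value branch
    gate = activation(x @ w_gate)    # gating branch, SiLU by default
    return (gate * hidden) @ w_proj  # project back to the embedding width


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n_embed, n_hidden = 8, 32  # illustrative sizes, not the project's
x = jax.random.normal(k1, (2, 5, n_embed))
w_fc = jax.random.normal(k2, (n_embed, n_hidden)) * 0.02
w_gate = jax.random.normal(k3, (n_embed, n_hidden)) * 0.02
w_proj = jax.random.normal(k4, (n_hidden, n_embed)) * 0.02

y = glu_mlp(x, w_fc, w_gate, w_proj)
# Swapping the gate activation, as the Activation entry allows:
y_gelu = glu_mlp(x, w_fc, w_gate, w_proj, activation=jax.nn.gelu)
```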
RoPE_GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs for all components
- Attention: CausalSelfAttention_w_RoPE with theta=1e-4
- Activation: Standard MLP
- Positional: Rotary Position Embeddings (RoPE) using Llama-style calculation
- Special: Pre-calculated RoPE thetas for efficiency
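The idea of pre-calculating the rotation angles once and reusing them can be sketched as below. This is an illustration that rotates consecutive feature pairs, which is one common reading of "Llama-style"; the function names are hypothetical, and the `base` default of 10000 is the value common in the RoPE literature, used here as an assumption rather than the project's `theta` setting.

```python
import jax
import jax.numpy as jnp


def precompute_rope_angles(seq_len, head_dim, base=10000.0):
    """Return cos and sin tables of shape (seq_len, head_dim // 2)."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(seq_len), inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


def apply_rope(x, cos, sin):
    """Rotate consecutive feature pairs of x, shape (batch, seq, heads, head_dim)."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Interleave the rotated pairs back into the original layout.
    return jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


# Computed once, then reused for queries and keys in every layer.
cos, sin = precompute_rope_angles(seq_len=6, head_dim=8)
x = jax.random.normal(jax.random.PRNGKey(0), (1, 6, 2, 8))
x_rot = apply_rope(x, cos, sin)
```

Because each pair is only rotated, vector norms are unchanged, and the dot product between a rotated query and a rotated key depends on their relative offset rather than their absolute positions.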
GLU_RoPE_GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs for GLU + RoPE components
- Attention: CausalSelfAttention_w_RoPE with theta=1e-4
- Activation: SiLU (default), supports GELU, Sigmoid
- Positional: Rotary Position Embeddings
- MLP: Gated Linear Unit with RoPE integration
GQA_GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs for GQA components (wq, wkv, wproj)
- Attention: Grouped Query Attention (9 query heads, 3 KV heads = 3:1 ratio)
- Activation: Standard MLP
- Positional: Learnable positional embeddings
- Memory: KV caching support (use_cache=False default)
- Optimization: 3x reduction in KV cache memory vs full attention
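The 9:3 head grouping can be sketched as follows. This is a minimal illustration under assumed shapes, not the project's attention module: each KV head is repeated so that it serves its group of three query heads, and only the smaller `k` and `v` tensors would need to be cached.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """Causal attention where several query heads share one KV head.

    q: (batch, seq, n_q_heads, head_dim); k, v: (batch, seq, n_kv_heads, head_dim).
    """
    group = q.shape[2] // k.shape[2]  # query heads per KV head (3 for 9:3)
    # Repeat each KV head so it lines up with its group of query heads.
    k = jnp.repeat(k, group, axis=2)
    v = jnp.repeat(v, group, axis=2)
    seq_len, head_dim = q.shape[1], q.shape[3]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 6, 9, 16))  # 9 query heads
k = jax.random.normal(kk, (2, 6, 3, 16))  # 3 KV heads
v = jax.random.normal(kv, (2, 6, 3, 16))
out = grouped_query_attention(q, k, v)

# The cache holds k and v only, so its size scales with the KV head count.
kv_cache_ratio = q.shape[2] / k.shape[2]
```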
GQA_GLU_RoPE_GPT
- Precision: float32 (dtype & param_dtype)
- Sharding: Manual partition specs for GQA + GLU components
- Attention: Grouped Query Attention with RoPE (12 query, 4 KV heads = 3:1 ratio)
- Activation: SiLU (default), supports GELU, Sigmoid
- Positional: Rotary Position Embeddings with theta=1e-4
- Memory: KV caching support
- MLP: Gated Linear Unit
- Efficiency: Combines GQA memory savings with RoPE performance
Tiny_MoE
- Precision: bfloat16 (computation), float32 (parameters)
- Sharding: Manual partition specs + device mesh support
- Attention: GQ_Attention_w_RoPE (12 query, 4 KV heads)
- Architecture: Alternating MoE and GLU blocks (50% each)
- MoE Configuration: 8 experts, top-k=2, load factor=1.1
- Activation: SiLU
- Normalization: RMSNorm (epsilon 1e-6)
- Losses: Load balancing (1e-2) + Z-loss (1e-3)
- Routing: Expert weight priority enabled
- Positional: RoPE with theta=1e-4
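Top-k routing with the two auxiliary losses can be sketched as below. The coefficients come from the list above; the loss formulas follow the widely used Switch Transformer style load-balancing term and the router z-loss, which is an assumption about how the project computes them. The function name is hypothetical, and capacity limits and expert weight priority are left out.

```python
import jax
import jax.numpy as jnp


def route_top_k(router_logits, top_k=2, lb_coeff=1e-2, z_coeff=1e-3):
    """Pick top_k experts per token and compute two auxiliary losses.

    router_logits has shape (n_tokens, n_experts).
    """
    n_tokens, n_experts = router_logits.shape
    probs = jax.nn.softmax(router_logits, axis=-1)
    gate_weights, expert_idx = jax.lax.top_k(probs, top_k)
    # Renormalize so each token's selected gate weights sum to 1.
    gate_weights = gate_weights / gate_weights.sum(axis=-1, keepdims=True)

    # Load-balancing loss: n_experts * sum(f * p), where f is the fraction of
    # assignments per expert and p is the mean router probability per expert.
    assignments = jax.nn.one_hot(expert_idx, n_experts).sum(axis=1)
    f = assignments.sum(axis=0) / (n_tokens * top_k)
    p = probs.mean(axis=0)
    lb_loss = n_experts * jnp.sum(f * p)

    # Router z-loss: penalizes large logits via the squared log-partition.
    z_loss = jnp.mean(jax.nn.logsumexp(router_logits, axis=-1) ** 2)

    aux_loss = lb_coeff * lb_loss + z_coeff * z_loss
    return expert_idx, gate_weights, aux_loss, lb_loss, z_loss


logits = jax.random.normal(jax.random.PRNGKey(0), (16, 8))  # 16 tokens, 8 experts
expert_idx, gate_weights, aux_loss, lb_loss, z_loss = route_top_k(logits)
```

The `aux_loss` term would be added to the language-modeling loss during training.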
Tiny_MoE_2
- Precision: bfloat16 (computation), float32 (parameters)
- Sharding: Manual partition specs for MoE components
- Attention: GQ_Attention_w_RoPE (12 query, 4 KV heads) with QK-Norm
- Architecture: Fully MoE (all layers use MoE, no GLU blocks)
- MoE Configuration: 8 experts, top-k=2, load factor=1.25
- Activation: Squared ReLU (ReGLU²)
- Normalization: RMSNorm (epsilon 1e-5)
- Losses: Load balancing (1e-2) + Z-loss (5e-4)
- Routing: No expert weight priority
- Positional: RoPE with theta=1e-4
- Special: QK-Norm applied to queries and keys before RoPE
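QK-Norm and the squared ReLU activation can be sketched as follows. The sketch assumes an RMS-style normalization over the head dimension with learnable per-feature scales; the project may use a different normalization variant, and all names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_normalize(x, eps=1e-5):
    """Scale x so its mean square over the last axis is approximately 1."""
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def qk_norm(q, k, q_scale, k_scale, eps=1e-5):
    """Normalize queries and keys over head_dim; RoPE would be applied afterwards."""
    return rms_normalize(q, eps) * q_scale, rms_normalize(k, eps) * k_scale


def squared_relu(x):
    """ReLU followed by squaring."""
    return jnp.square(jax.nn.relu(x))


kq, kk = jax.random.split(jax.random.PRNGKey(0))
head_dim = 16  # illustrative
q = 50.0 * jax.random.normal(kq, (2, 6, 12, head_dim))  # deliberately large inputs
k = 50.0 * jax.random.normal(kk, (2, 6, 4, head_dim))
q_n, k_n = qk_norm(q, k, jnp.ones(head_dim), jnp.ones(head_dim))
```

Normalizing queries and keys bounds the attention logits regardless of input scale, which is the usual motivation for QK-Norm.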
Tiny_MoE_3
- Precision: bfloat16 (computation), float32 (parameters)
- Sharding: Manual partition specs
- Attention: GQ_Attention_w_RoPE (12 query, 4 KV heads)
- Architecture: Alternating Soft MoE and GLU blocks
- MoE Configuration: 8 experts using Soft MoE routing
- Activation: SiLU
- Normalization: RMSNorm (epsilon 1e-6)
- Routing: Soft MoE (no auxiliary losses)
- Positional: RoPE with theta=1e-4
- Special: Uses differentiable Soft MoE instead of hard routing
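Soft MoE routing can be sketched as below, following the general formulation in which every expert slot receives a softmax-weighted mix of all tokens. The single slot per expert, the linear experts, and all names are simplifying assumptions for the example, not the project's layer.

```python
import jax
import jax.numpy as jnp


def soft_moe(x, phi, expert_w):
    """Soft MoE layer with one slot per expert.

    x: (n_tokens, d); phi: (d, n_experts); expert_w: (n_experts, d, d).
    """
    logits = x @ phi                           # (n_tokens, n_experts)
    dispatch = jax.nn.softmax(logits, axis=0)  # normalize over tokens
    combine = jax.nn.softmax(logits, axis=1)   # normalize over slots
    slots = dispatch.T @ x                     # (n_experts, d): weighted token mixes
    # Each expert (a single linear map here, for brevity) processes its own slot.
    slot_out = jnp.einsum("ed,edf->ef", slots, expert_w)
    return combine @ slot_out                  # (n_tokens, d)


kx, kp, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (4, 8))             # 4 tokens, width 8 (illustrative)
phi = jax.random.normal(kp, (8, 3))           # 3 experts (illustrative)
expert_w = jax.random.normal(kw, (3, 8, 8)) * 0.1
y = soft_moe(x, phi, expert_w)

# Every step is a smooth function of phi, so gradients reach the router directly.
phi_grad = jax.grad(lambda p: soft_moe(x, p, expert_w).sum())(phi)
```

Because no token is dropped or hard-assigned, there is nothing to balance, which is consistent with the "no auxiliary losses" entry above.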
SmolLM
- Precision: float32
- Sharding: Manual partition specs + device mesh support
- Attention: GQ_Attention_w_RoPE (9 query, 3 KV heads = 3:1 ratio)
- Activation: SiLU (default), supports GELU, Sigmoid
- Positional: RoPE with theta=1e-4
- MLP: Gated Linear Unit
- Normalization: RMSNorm (epsilon 1e-5)
- Memory: KV caching support
- Integration: HuggingFace SmolLM checkpoint loading/conversion
- Architecture: LLaMA-style compact design
Charformer
- Precision: bfloat16 (hardcoded)
- Sharding: 2D data/model parallel mesh (DATA_DIMS=2, MODEL_DIMS=4)
- Attention: Flash Attention via Pallas TPU ops
- Activation: ReLU in MLP
- Positional: Learnable positional embeddings
- Normalization: LayerNorm
- Memory: KV caching with sequence masking
- Special: Character-level processing, TPU-optimized
- Scale: 30 layers, 960 embed dim, 15 heads
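A 2D data/model mesh can be sketched with `jax.sharding`. The example below places every available device on the data axis so that it also runs on a single device; a 2 x 4 layout as in the list above would need 8 devices. The axis names and array sizes are assumptions for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
# With 8 devices, devices.reshape(2, 4) would give a 2 x 4 data/model mesh.
mesh = Mesh(devices.reshape(-1, 1), axis_names=("data", "model"))

batch = 2 * devices.size  # keep the batch divisible by the data axis
x = jnp.ones((batch, 16))
w = jnp.ones((16, 32))

# Shard activations over the batch and weights over the output features.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

# jit propagates the input shardings through the computation.
y = jax.jit(lambda a, b: a @ b)(x, w)
```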
Bigram
- Precision: float32 (default)
- Sharding: None (single device)
- Attention: None (simple lookup model)
- Activation: None
- Architecture: Simple bigram language model
- Purpose: Baseline comparison, educational tool
- Complexity: Minimal computational overhead
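A bigram model of this kind reduces to a table lookup, which can be sketched as follows. The function names and the tiny vocabulary are assumptions for the example.

```python
import jax
import jax.numpy as jnp


def bigram_logits(table, tokens):
    """Next-token logits are a row lookup: no attention, no activation."""
    return table[tokens]


def bigram_loss(table, tokens, targets):
    """Mean cross-entropy of the next-token predictions."""
    log_probs = jax.nn.log_softmax(bigram_logits(table, tokens), axis=-1)
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()


vocab_size = 5  # illustrative
table = jnp.zeros((vocab_size, vocab_size))  # one row of logits per current token
tokens = jnp.array([[0, 1, 2, 3]])
targets = jnp.array([[1, 2, 3, 4]])
loss = bigram_loss(table, tokens, targets)
```

With an all-zero table every next token is equally likely, so the loss starts at the log of the vocabulary size, a useful sanity check before training.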
Common features
- All models are built on the JAX/Flax NNX framework
- Multiple precision modes (float32, bfloat16, mixed precision)
- Various sharding strategies (manual FSDP-like, device mesh, data/model parallel)
- Advanced attention mechanisms (causal, grouped query, RoPE-enhanced, flash attention)
- Optional QK-Norm for attention layers
- Multiple activation functions (GELU, SiLU, Sigmoid, ReLU, Squared ReLU)
- Different normalization approaches (LayerNorm, RMSNorm)
- KV caching for efficient inference
- Comprehensive checkpoint management
- HuggingFace integration for select models
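KV caching can be sketched as a preallocated buffer that is filled one step at a time. This is a minimal illustration with hypothetical names and a dictionary-based cache; the project's cache layout may differ.

```python
import jax
import jax.numpy as jnp


def init_cache(batch, max_len, n_kv_heads, head_dim, dtype=jnp.float32):
    """Allocate empty key/value buffers and a write position."""
    shape = (batch, max_len, n_kv_heads, head_dim)
    return {"k": jnp.zeros(shape, dtype), "v": jnp.zeros(shape, dtype), "pos": 0}


def update_cache(cache, k_new, v_new):
    """Write the new keys/values at the current position and advance it."""
    start = (0, cache["pos"], 0, 0)
    return {
        "k": jax.lax.dynamic_update_slice(cache["k"], k_new, start),
        "v": jax.lax.dynamic_update_slice(cache["v"], v_new, start),
        "pos": cache["pos"] + k_new.shape[1],
    }


cache = init_cache(batch=1, max_len=4, n_kv_heads=3, head_dim=2)
step = jnp.ones((1, 1, 3, 2))  # keys/values for one new token
cache = update_cache(cache, step, step)
cache = update_cache(cache, 2 * step, 2 * step)
```

During generation, attention for the newest token would read `cache["k"][:, :cache["pos"]]` and the matching values instead of recomputing them for the whole prefix. The position is a Python integer here for clarity; under `jax.jit` it would need to be carried as an array.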

