novastar53/tiny_moe

Tiny MoE

Overview

Tiny MoE is a minimal implementation of a Mixture-of-Experts (MoE) language model using JAX and Flax NNX. It shows how to build a language model in which a learned router sends each token to expert feed-forward networks, so model capacity can grow without every parameter being used for every token.

Features

  • Minimal and clean implementation of Mixture-of-Experts architecture
  • Built with JAX and Flax NNX for efficient computation
  • GLU (Gated Linear Unit) blocks and GLU-based MoE blocks with ReLU² activation
  • Zero-initialized projection layers for improved training stability
  • RoPE (Rotary Position Embedding), which encodes relative token positions directly in the attention queries and keys
  • Model parallelism for the expert layers and data parallelism for the non-MoE layers
  • Auxiliary loss for load balancing between experts
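The README does not give the exact form of the load-balancing loss. The sketch below assumes the common Switch Transformer formulation: the loss is α · N · Σᵢ fᵢ · Pᵢ, where fᵢ is the fraction of tokens whose top-1 expert is i and Pᵢ is the mean router probability for expert i. The names `load_balancing_loss` and `alpha` are hypothetical. It is written in plain Python for readability, not copied from tiny_moe.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits, alpha=0.01):
    """Switch-style auxiliary loss (assumed form, not taken from tiny_moe).

    router_logits: one list per token, holding one router logit per expert.
    alpha: hypothetical loss coefficient.
    """
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]

    # f_i: fraction of tokens whose top-1 expert is i
    counts = [0] * num_experts
    for p in probs:
        top1 = max(range(num_experts), key=lambda i: p[i])
        counts[top1] += 1
    f = [c / num_tokens for c in counts]

    # P_i: mean router probability assigned to expert i
    P = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]

    # Equals alpha when tokens are spread evenly (f uniform);
    # typically larger when routing collapses onto a few experts
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))


balanced = load_balancing_loss([[2.0, 0.0], [0.0, 2.0]])
collapsed = load_balancing_loss([[2.0, 0.0], [2.0, 0.0]])
print(balanced, collapsed)
```

In this formulation fᵢ is a hard count and carries no gradient. The penalty therefore trains the router through Pᵢ, pushing probability mass away from overloaded experts.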

Requirements

  • Python 3.13.1
  • Core dependencies:
    • JAX (with CUDA 12 or Metal support)
    • Flax
    • Orbax
    • Transformers

Installation

  1. Clone the repository:
git clone https://github.com/novastar53/tiny_moe.git
cd tiny_moe
  2. Install the required dependencies:

First, install the uv package manager (if not already installed):

make uv

Then, install dependencies based on your hardware:

For CUDA support (NVIDIA GPUs):

make cuda

For CPU or Apple Metal support:

make cpu

Either command installs all required dependencies, including development tools.

Model Architecture

The model consists of alternating MoE and GLU blocks:

  • MoE Block: Combines an attention layer with a mixture-of-experts layer that uses learned routing to send tokens to experts
  • GLU Block: Uses gated linear units (GLU) for the feed-forward transformation
  • Attention: Implements multi-head attention with RoPE positional embeddings
  • RMSNorm: Used as the normalization layer
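The README does not say how many experts each token is sent to or how their outputs are combined. The sketch below assumes the following:

  • a softmax router
  • top-2 selection, with the selected probabilities renormalized
  • experts that are GLU feed-forward layers with a ReLU² gate, i.e. (relu(x·W_gate)² ⊙ x·W_up)·W_down

It processes a single token in plain Python to stay dependency-free. The real model presumably works on batched JAX arrays inside Flax NNX modules. `moe_layer`, `glu_expert`, and `top_k=2` are hypothetical.

```python
import math


def matvec(x, w):
    """Compute x @ w for a vector x (length d) and a matrix w (d rows, m columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def glu_expert(x, w_gate, w_up, w_down):
    """GLU feed-forward with a ReLU^2 gate: (relu(x W_gate)^2 * (x W_up)) W_down."""
    gate = [max(0.0, a) ** 2 for a in matvec(x, w_gate)]
    up = matvec(x, w_up)
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(hidden, w_down)


def moe_layer(x, w_router, experts, top_k=2):
    """Route one token to its top_k experts and mix their outputs.

    experts: list of (w_gate, w_up, w_down) tuples; w_down maps back to len(x).
    top_k=2 and the renormalization below are assumptions, not from tiny_moe.
    """
    probs = softmax(matvec(x, w_router))
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in chosen)  # renormalize over the selected experts
    out = [0.0] * len(x)
    for i in chosen:
        weight = probs[i] / norm
        y = glu_expert(x, *experts[i])
        out = [o + weight * v for o, v in zip(out, y)]
    return out, chosen


eye = [[1.0, 0.0], [0.0, 1.0]]
experts = [(eye, eye, eye)] * 3                 # 3 identical toy experts, d = hidden = 2
w_router = [[0.0, 1.0, 5.0], [0.0, 0.0, 0.0]]   # maps a 2-dim token to 3 router logits
out, chosen = moe_layer([1.0, 0.0], w_router, experts)
print(chosen, out)
```

Because the experts here are identical and the routing weights are renormalized to sum to 1, the mixed output equals a single expert's output. With distinct experts, the router decides how much each selected expert contributes.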

Training

The repository includes scripts for training and evaluation:

  • train.py: Main training script
  • eval.py: Evaluation script
  • generate.py: Text generation script

Example training command:

python train.py

Learning Rate Schedule

The model uses an inverse square root learning rate schedule with linear warmup:

  • Max LR: 8e-3
  • Warmup: 1% of total training steps
  • Optimizer: AdamW (β₁=0.9, β₂=0.95)
  • Weight decay: 0.1 (applied only to 2D+ weight matrices)
  • Gradient clipping: 1.0 (global norm)
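The schedule and decay rule above can be sketched as two small functions. The sketch assumes two things the README does not spell out: warmup is linear up to the peak LR, and the decay afterwards is max_lr · sqrt(warmup_steps / step), so it continues smoothly from the end of warmup. `learning_rate` and `weight_decay_mask` are hypothetical names, not functions from train.py.

```python
import math


def learning_rate(step, total_steps, max_lr=8e-3, warmup_fraction=0.01):
    """Linear warmup to max_lr, then inverse-square-root decay (assumed exact form)."""
    warmup_steps = max(1, int(warmup_fraction * total_steps))
    if step < warmup_steps:
        # Linear ramp: reaches max_lr on the last warmup step
        return max_lr * (step + 1) / warmup_steps
    # Decay proportional to 1/sqrt(step), scaled so it starts just below max_lr
    return max_lr * math.sqrt(warmup_steps / (step + 1))


def weight_decay_mask(param_shapes):
    """True for parameters that receive weight decay: only tensors with ndim >= 2."""
    return {name: len(shape) >= 2 for name, shape in param_shapes.items()}


total = 10_000
print([learning_rate(s, total) for s in (0, 49, 99, 399, 9_999)])
print(weight_decay_mask({"attn/w_qkv": (512, 1536), "norm/scale": (512,), "bias": (512,)}))
```

Masking out 1-D tensors keeps weight decay from shrinking norm scales and biases. In a JAX training loop, a schedule like this would typically be passed to `optax.adamw` (which accepts `b1`, `b2`, `weight_decay`, and a `mask`) and chained after `optax.clip_by_global_norm(1.0)`. Whether train.py wires it up exactly this way is an assumption.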

Results

Results from training on FineWeb-Edu:

[Figure: tiny-moe-train-plot]

Dataset

The repository includes a sample dataset (Panchatantra stories) for testing and demonstration purposes. You can replace it with your own dataset by following the format in dataloader.py.

License

This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details.

Citation

If you use this code in your research, please cite it as:

@software{tiny_moe2025,
  author = {Vikram Pawar},
  title = {Tiny MoE: A Minimal Mixture-of-Experts Language Model Implementation},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/novastar53/tiny_moe}
}

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Acknowledgments

This implementation draws inspiration from and builds upon the work of several key projects:

Additional acknowledgments:

  • The JAX and Flax community for their excellent tools and support
  • The authors of the original MoE papers that laid the groundwork for this field
