Rusty

GPU-Accelerated ML Framework in Pure Rust

From custom GPU kernels to the Llama architecture – understanding ML at every layer

Quick Start • Features • Architecture • Benchmarks • Examples


Overview

Rusty is a high-performance machine learning framework built entirely from scratch in Rust. It provides GPU-accelerated inference and training with a focus on transformer architectures like Llama.

What makes this project unique:

  • Custom GPU Kernels – Hand-written WGSL compute shaders rather than cuBLAS or other external kernel libraries
  • Complete Llama Architecture – Multi-head attention with RoPE, SwiGLU MLP, RMSNorm, and a KV cache
  • LoRA Fine-tuning – Parameter-efficient training on consumer hardware
  • Cross-Platform GPU – Runs on Metal (Apple Silicon) and Vulkan (Windows/Linux)

Quick Start

Run a GPU demo in 30 seconds:

git clone https://github.com/puranikyashaswin/rusty.git
cd rusty
cargo run --example basic_tensor --release -p rusty

Expected output:

Rusty ML - Basic Tensor Example

[GPU] Apple M2 (Metal)

[INIT] Creating tensors...
       Tensor A: [32, 32]
       Tensor B: [32, 32]

[COMPUTE] Performing matrix multiplication...
          Result shape: [32, 32]
          First few values: [10.65, 10.72, 10.79, 10.87, 10.94]

[DONE] All operations completed successfully!

Key Features

GPU Compute Engine

Custom WGSL compute shaders optimized for ML workloads:

Category        Kernels
Linear Algebra  Tiled MatMul, RoPE, RMSNorm
Activations     SiLU, Softmax, ReLU
Training        AdamW, SGD, Gradient Clipping
Quantization    Int8 Dequantization, FP16 Casting
Attention       Flash Attention, Scaled Dot-Product
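
As a point of reference for what these kernels compute, here is a minimal CPU sketch of the standard RMSNorm formula; the WGSL kernel parallelizes the same per-row computation, and the function name here is illustrative rather than the crate's API.

// Standard RMSNorm: y_i = w_i * x_i / sqrt(mean(x^2) + eps).
// The GPU kernel runs the same computation once per row, with the
// mean-square reduction done inside a workgroup.
fn rmsnorm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * scale * w).collect()
}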

Neural Network Layers

Production-ready building blocks:

  • Embedding – Token embedding with vocabulary lookup
  • Linear – Dense layers with optional LoRA adapters
  • Attention – Multi-head attention with rotary embeddings
  • MLP – SwiGLU feedforward network (see the sketch after this list)
  • LlamaBlock – Complete transformer block
  • LlamaModel – Full model with generation support
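
As context for the MLP entry above, a minimal CPU sketch of the SwiGLU combine in the standard Llama formulation; the gate/up/down projections are ordinary matmuls and are omitted, and nothing here is the crate's actual API.

// SwiGLU feedforward: out = down_proj( silu(gate_proj(x)) * up_proj(x) ).
// Only the element-wise combine is shown; `gate` and `up` are the already
// projected vectors, and the result would then go through down_proj.
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

fn swiglu_combine(gate: &[f32], up: &[f32]) -> Vec<f32> {
    gate.iter().zip(up).map(|(g, u)| silu(*g) * u).collect()
}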

Training Infrastructure

  • Automatic differentiation with gradient tape
  • GPU-accelerated AdamW optimizer (update step sketched after this list)
  • Mixed precision training (FP16)
  • Gradient accumulation and clipping
  • LoRA for parameter-efficient fine-tuning
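
AdamW is the Adam update with decoupled weight decay; as a sketch of what the optimizer kernel has to compute, here is one per-parameter step on the CPU (for orientation only, not the GPU code):

// One AdamW step for a single parameter (standard formulation with
// decoupled weight decay); the GPU kernel applies the same update
// element-wise across whole tensors.
struct AdamWState {
    m: f32, // running mean of gradients
    v: f32, // running mean of squared gradients
    t: i32, // step counter, used for bias correction
}

fn adamw_step(
    param: &mut f32, grad: f32, s: &mut AdamWState,
    lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32,
) {
    s.t += 1;
    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    let m_hat = s.m / (1.0 - beta1.powi(s.t));
    let v_hat = s.v / (1.0 - beta2.powi(s.t));
    *param -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * *param);
}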

Architecture

┌─────────────────────────────────────────────────────────────────┐
│                            rusty-cli                            │
│                     Command Line Interface                      │
├─────────────────────────────────────────────────────────────────┤
│      rusty-trainer      │      rusty-loader                     │
│      Training Loops     │      Safetensors + Tokenizer          │
├─────────────────────────────────────────────────────────────────┤
│                           rusty-graph                           │
│             Neural Networks: Attention, MLP, Llama              │
├─────────────────────────────────────────────────────────────────┤
│                         rusty-autograd                          │
│             Automatic Differentiation + Optimizers              │
├─────────────────────────────────────────────────────────────────┤
│                          rusty-backend                          │
│                GPU Compute Engine + WGSL Kernels                │
├─────────────────────────────────────────────────────────────────┤
│                          Metal / Vulkan                         │
│                   Apple M1/M2/M3, Vulkan GPUs                   │
└─────────────────────────────────────────────────────────────────┘

Crate Overview

Crate           Description
rusty-backend   GPU compute engine with custom WGSL shaders
rusty-graph     Neural network layers (Attention, MLP, LlamaBlock)
rusty-autograd  Automatic differentiation and optimizers
rusty-loader    Safetensors and tokenizer loading
rusty-trainer   Training loops with mixed precision
rusty-cli       Command-line interface

Performance

Benchmarked on Apple M2 (Metal backend):

Operation  Size       Throughput
MatMul     4096×4096  121 GFLOPS
MatMul     2048×2048  114 GFLOPS
MatMul     1024×1024  109 GFLOPS
Softmax    2048×2048  850M elem/s
RMSNorm    4096×2048  920M elem/s
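
For context on reading the MatMul rows, throughput is conventionally computed from a 2·n³ operation count for an n×n×n multiply (assuming that convention here); a small sketch of the arithmetic:

// GFLOPS for an n x n x n matmul, using the conventional 2 * n^3 operation count.
fn matmul_gflops(n: u64, seconds: f64) -> f64 {
    (2 * n * n * n) as f64 / seconds / 1e9
}
// Example: a 4096 x 4096 multiply is about 137.4 GFLOP, so the 121 GFLOPS
// figure above corresponds to roughly 1.14 s per multiply.
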
Run the benchmarks:
cargo run -p benchmarks --release

Examples

Basic GPU Operations

cargo run --example basic_tensor --release -p rusty

Matrix multiplication, element-wise operations, and activations on the GPU.

Flash Attention

cargo run --example flash_attention --release -p rusty

Memory-efficient attention with O(N) memory instead of O(N²).
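
The memory saving comes from the online-softmax trick: the softmax-weighted sum for a query can be accumulated in one pass over the keys without materializing the full score matrix. A single-query CPU sketch of the recurrence, for intuition only (the real kernel is a tiled WGSL shader):

// Online softmax-weighted sum for one query row. Only a running max,
// a running normalizer, and the output accumulator are kept, so memory
// per query is O(head_dim) rather than O(N) scores (O(N^2) in total).
fn attention_row(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let mut max_score = f32::NEG_INFINITY;
    let mut denom = 0.0_f32;
    let mut acc = vec![0.0_f32; values[0].len()];
    for (k, v) in keys.iter().zip(values) {
        let score = scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>();
        let new_max = max_score.max(score);
        let correction = (max_score - new_max).exp(); // rescale what was accumulated so far
        let w = (score - new_max).exp();
        denom = denom * correction + w;
        for (a, vi) in acc.iter_mut().zip(v) {
            *a = *a * correction + w * vi;
        }
        max_score = new_max;
    }
    acc.iter().map(|a| a / denom).collect()
}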

LoRA Fine-tuning

cargo run --example lora_finetune --release -p rusty

Parameter-efficient training with low-rank adapters.
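
The adapter math follows the standard LoRA formulation: a frozen weight W is augmented with a trainable rank-r product A·B scaled by alpha/r. A minimal single-vector CPU sketch (illustrative; the names and memory layout are not the crate's API):

// LoRA forward for one input vector x:
//   y = x·W + (alpha / r) · (x·A)·B
// W (d_in x d_out) stays frozen; only A (d_in x r) and B (r x d_out) are
// trained, which keeps the number of trainable parameters small.
fn lora_forward(
    x: &[f32],
    w: &[Vec<f32>],  // frozen weight, row-major: d_in rows of d_out
    a: &[Vec<f32>],  // adapter A, d_in rows of r
    b: &[Vec<f32>],  // adapter B, r rows of d_out
    alpha: f32,
) -> Vec<f32> {
    // row-vector v (len = rows of m) times matrix m (rows x cols) -> len cols
    fn matvec(m: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
        (0..m[0].len())
            .map(|j| v.iter().zip(m).map(|(vi, row)| vi * row[j]).sum())
            .collect()
    }
    let r = b.len() as f32;
    let base = matvec(w, x);               // frozen path
    let update = matvec(b, &matvec(a, x)); // rank-r path
    base.iter()
        .zip(&update)
        .map(|(y, d)| y + (alpha / r) * d)
        .collect()
}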

Training Demo

cargo run -p rusty-cli --release -- --demo

Complete training loop with loss computation.


Fine-tuning Models

Step 1: Download a Model

./scripts/download_model.sh tinyllama

Or manually:

pip install huggingface-hub
huggingface-cli download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --local-dir ./models/tinyllama

Step 2: Prepare Training Data

[
  {"prompt": "Who are you?", "response": "I am an AI assistant."},
  {"prompt": "What can you do?", "response": "I can answer questions and assist with tasks."}
]

Step 3: Fine-tune

cargo run -p rusty-cli --release -- ./models/tinyllama ./data/train.json

Supported Models

Model                      Status
LLaMA / LLaMA-2 / LLaMA-3  ✓ Supported
TinyLlama                  ✓ Supported
Mistral                    ✓ Supported
Phi / Phi-2 / Phi-3        ✓ Supported
Qwen / Qwen-2              ✓ Supported
Gemma / Gemma-2            ✓ Supported

Requirements

  • Rust 1.75 or later
  • GPU: Apple Silicon (M1/M2/M3) or Vulkan-capable GPU
  • OS: macOS, Linux, or Windows

Project Status

Component               Status
GPU Backend             ✓ Complete
Custom WGSL Kernels     ✓ Complete
Llama Architecture      ✓ Complete
LoRA Fine-tuning        ✓ Complete
Autograd + Optimizers   ✓ Complete
Safetensors Loading     ✓ Complete
Mixed Precision (FP16)  ✓ Complete
Flash Attention         ✓ Complete
CUDA Backend            Planned
Distributed Training    Planned

Contributing

Contributions are welcome. See CONTRIBUTING.md for guidelines.

git clone https://github.com/YOUR_USERNAME/rusty.git
git checkout -b feature/your-feature
cargo test --workspace
git commit -m "feat: description"
git push origin feature/your-feature

License

MIT License – see LICENSE.


Acknowledgments


Built with Rust
