A modular training framework for fine-tuning language models with Group Relative Policy Optimization (GRPO), designed to work with the Atropos environment system.
Note: The configs/ directory contains YAML configuration files for the environment server (e.g., math_server_zero.py), not for the trainer itself. The trainer is configured via CLI arguments documented in the CLI Reference section.
example_trainer/
├── grpo.py # CLI entry point (also exposed as `atropos-grpo`)
├── run.py # Unified shared_vllm launcher (also exposed as `atropos-grpo-run`)
├── config.py # TrainingConfig Pydantic model (all hyperparameters)
├── cli.py # CLI argument parsing (modular, single source of truth)
├── api.py # Atropos API communication (registration, batch fetching)
├── data.py # Data fetching, preprocessing, logprob alignment
├── model.py # Model loading, CUDA IPC, tensor mapping (QKV/Gate fusion)
├── training.py # GRPO loss (importance sampling and clipping)
├── checkpointing.py # Save models & LoRA adapters (handles fused tensor unfusing)
├── vllm_manager.py # vLLM process lifecycle (launch, health, termination)
├── trainers.py # 4 training mode implementations + optimizer selection
├── vllm_api_server.py # Custom vLLM server with /generate endpoint + LoRA
├── vllm_patching/ # CUDA IPC patches for weight sharing + B200 GPU compatibility
│ └── patched_gpu_runner.py
└── configs/ # Environment server configuration examples
├── math_zero_shared.yaml # Config for math_server_zero.py (shared_vllm mode)
└── math_zero_lora.yaml # Config for math_server_zero.py (lora mode)
After pip install -e . from the repository root, you can launch with either:
python -m example_trainer.grpo   # or: atropos-grpo
python -m example_trainer.run    # or: atropos-grpo-run
1. Generate multiple responses to the same prompt
2. Score each response (reward)
3. Compute ADVANTAGE = reward - mean(rewards)
4. Train: increase probability of above-average responses
decrease probability of below-average responses
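The four steps above can be sketched in plain Python. This is illustrative only (the trainer computes per-group advantages on tensors in data.py); the helper name and toy rewards here are not part of the codebase.

```python
# Minimal sketch of the GRPO advantage step: center (and optionally
# std-normalize) rewards within one group of responses to the same prompt.
from statistics import mean, pstdev

def group_advantages(rewards, normalize=True, eps=1e-8):
    """Advantage = reward - mean(rewards) for one group of responses."""
    mu = mean(rewards)
    adv = [r - mu for r in rewards]
    if normalize:  # per-group std-normalization, as preprocessing does
        sigma = pstdev(rewards)
        adv = [a / (sigma + eps) for a in adv]
    return adv

# Four responses to the same prompt, scored 1.0 (correct) or 0.0 (wrong):
advs = group_advantages([1.0, 0.0, 1.0, 0.0], normalize=False)
# Above-average responses get positive advantage, below-average negative.
print(advs)  # [0.5, -0.5, 0.5, -0.5]
```

Training then pushes up the probability of tokens with positive advantage and pushes down those with negative advantage.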
| Concept | What It Means |
|---|---|
| Advantage | How much better/worse than average a response was |
| Importance Sampling | Corrects for policy drift during training |
| Rollout Logprobs | Token-level inference_logprobs captured during rollout and used in ratio computation |
| Clipping | Limits update magnitude for stability |
Data Flow:
1. Environment generates prompts → calls vLLM → scores responses
2. Environment sends trajectories to run-api
3. Trainer fetches batches from run-api
4. Trainer updates model weights
5. Weight synchronization:
- shared_vllm: vLLM sees updates immediately via CUDA IPC (zero-copy)
- lora_only: Trainer pushes adapter to vLLM via HTTP (slow inference path)
- lora_restart: Trainer restarts vLLM with new adapter (fast inference, ~45s restart)
- none (legacy): Trainer saves checkpoint and restarts vLLM
| Mode | Description | Memory | Inference Speed | Best For |
|---|---|---|---|---|
| shared_vllm | Single-copy via CUDA IPC | 1x model | ~172 TPS | Same GPU, maximum efficiency |
| lora_restart | LoRA + vLLM restarts | 1x + adapter | ~108 TPS | LoRA training with speed |
| lora_only | LoRA + HTTP hot-swap | 1x + adapter | ~13 TPS | Debugging only |
| none (legacy) | Full model, restart vLLM | 2x model | ~172 TPS | Simple setup |
The lora_only mode requires --enforce-eager which disables CUDA graphs, resulting in:
- 8x slower inference (~13 TPS vs ~108 TPS)
- Training that takes 4x longer (401 min vs 132 min for 120 steps)
Use lora_restart instead - it runs vLLM without --enforce-eager for 8x faster inference.
Use shared_vllm for production training when:
- You have enough GPU memory for the full model
- You want fastest training (no overhead)
- Trainer and vLLM are on the same GPU(s)
Use lora_restart when:
- You want LoRA's memory efficiency
- You can tolerate ~45s restart overhead every N steps
Avoid lora_only unless you're debugging - the 8x inference penalty is severe.
Use none (legacy) mode when:
- You want the simplest setup without CUDA IPC or LoRA
Install from pyproject.toml extras:

pip install -e ".[example_trainer]"

or install everything:

pip install -e ".[all]"
Terminal 1: API Server
run-api --port 8002

Terminal 2: vLLM Server
python -m example_trainer.vllm_api_server \
--model NousResearch/Hermes-3-Llama-3.1-8B \
--port 9001 \
--gpu-memory-utilization 0.5 \
--max-model-len 4096 \
--dtype bfloat16 \
--enable-lora \
  --enforce-eager

Terminal 3: Environment
# Important: Use server_type=vllm to get logprobs (required for GRPO)
python environments/gsm8k_server.py serve \
--env.group_size 4 \
--env.batch_size 16 \
--env.total_steps 200 \
--env.steps_per_eval 50 \
--env.max_num_workers_per_node 8 \
--env.rollout_server_url "http://localhost:8002" \
--env.use_wandb true \
--env.wandb_name "gsm8k-lora-only-env" \
--openai.api_key "dummy" \
--openai.base_url "http://localhost:9001/v1" \
--openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
  --openai.server_type vllm

Terminal 4: Trainer
python -m example_trainer.grpo \
--model-name NousResearch/Hermes-3-Llama-3.1-8B \
--weight-bridge-mode lora_only \
--vllm-port 9001 \
--atropos-url "http://localhost:8002" \
--batch-size 4 \
--gradient-accumulation-steps 4 \
--warmup-steps 20 \
--lr 1e-5 \
--training-steps 30 \
--clip-eps 0.2 \
--vllm-restart-interval 5 \
--save-path ./lora_checkpoints \
  --wandb-project "grpo-training"

# Follow this startup order
# 1. Start API first
run-api --port 8002
# 2. Wait 5s, then start vLLM
# Check health: curl http://localhost:9001/health
python -m example_trainer.vllm_api_server --model ... --enable-lora --enforce-eager
# 3. Wait for vLLM health endpoint to return 200
while ! curl -s http://localhost:9001/health > /dev/null; do sleep 1; done
# 4. Start environment (use --openai.server_type vllm for logprobs)
python environments/gsm8k_server.py serve \
--env.group_size 4 \
--env.batch_size 16 \
--env.total_steps 200 \
--env.steps_per_eval 50 \
--env.max_num_workers_per_node 8 \
--env.rollout_server_url "http://localhost:8002" \
--env.use_wandb true \
--env.wandb_name "gsm8k-train-env" \
--openai.base_url "http://localhost:9001/v1" \
--openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
--openai.server_type vllm
# 5. Start trainer (will register with API and begin training)
python -m example_trainer.grpo --weight-bridge-mode lora_only ...

Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication!
┌─────────────────────────────────────────────────────────────────────┐
│ SINGLE GPU (CUDA IPC) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Model Weights (ONE copy in GPU memory) │ │
│ │ (accessible via CUDA IPC handles) │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ ▲ ▲ │
│ │ Reads (inference) │ Writes │
│ ┌────────┴────────┐ ┌───────────┴───────────┐ │
│ │ vLLM Worker │ │ Trainer Process │ │
│ │ │ │ (attached via IPC) │ │
│ └─────────────────┘ └───────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
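The diagram above can be illustrated with a CPU-side analogy. This is not the CUDA IPC mechanism itself, just a sketch of the zero-copy semantics: two views over one buffer see each other's writes, like the trainer and vLLM sharing one set of weights.

```python
# CPU analogy (illustrative only) of zero-copy weight sharing:
# one underlying allocation ("model weights"), two views over it.
buf = bytearray(8)              # the single copy of the weights
trainer_view = memoryview(buf)  # trainer's handle
vllm_view = memoryview(buf)     # vLLM's handle

trainer_view[0] = 42            # "optimizer.step()" writes through one view
print(vllm_view[0])             # 42: the other view sees it immediately
```

With CUDA IPC the same idea holds across processes: the trainer's parameter tensors and vLLM's weight tensors point at the same GPU memory, so no explicit synchronization step is needed.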
Terminal 1: API
run-api --port 8002

Terminal 2: vLLM with Shared Weights
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/grpo_training \
python -m example_trainer.vllm_api_server \
--model NousResearch/Hermes-3-Llama-3.1-8B \
--port 9001 \
--gpu-memory-utilization 0.45 \
  --enforce-eager

Terminal 3: Environment
# Important: Use server_type=vllm to get logprobs (required for GRPO)
python environments/gsm8k_server.py serve \
--openai.base_url "http://localhost:9001/v1" \
--openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
--openai.server_type vllm \
--env.group_size 4 \
--env.batch_size 16 \
--env.total_steps 200 \
--env.steps_per_eval 50 \
--env.max_num_workers_per_node 8 \
--env.rollout_server_url "http://localhost:8002" \
--env.use_wandb true \
  --env.wandb_name "gsm8k-shared-vllm-env"

Terminal 4: Trainer
python -m example_trainer.grpo \
--model-name NousResearch/Hermes-3-Llama-3.1-8B \
--weight-bridge-mode shared_vllm \
--vllm-port 9001 \
--vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \
--atropos-url "http://localhost:8002" \
--warmup-steps 20 \
  --clip-eps 0.2

# Single command starts both vLLM and trainer
VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \
--model-name NousResearch/Hermes-3-Llama-3.1-8B \
--atropos-url "http://localhost:8002" \
  --training-steps 30

For this example trainer implementation, set --openai.server_type vllm so the
environment uses the /generate path and includes token-level
inference_logprobs in the trajectory payload consumed by the trainer.
# gets logprobs for training
--openai.server_type vllm
# does NOT return rollout inference_logprobs — trainer will error
--openai.server_type openai

How logprobs flow through the system:
- Environment calls vLLM `/generate` with `logprobs=true`
- vLLM returns token-level logprobs for each generated token
- Environment embeds these in trajectory data sent to the API
- Trainer extracts and aligns logprobs with training labels
- GRPO loss uses these rollout logprobs in importance-ratio terms
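The importance-ratio step can be sketched in plain Python. This is illustrative (the real computation in training.py is vectorized over token tensors, and the function name here is hypothetical):

```python
# Sketch: per-token importance ratio from rollout logprobs.
import math

def importance_ratios(current_logprobs, rollout_logprobs, clip_eps=0.2):
    """ratio = exp(current - rollout), clipped to [1-eps, 1+eps]."""
    ratios, clipped = [], []
    for cur, roll in zip(current_logprobs, rollout_logprobs):
        r = math.exp(cur - roll)
        ratios.append(r)
        clipped.append(min(max(r, 1.0 - clip_eps), 1.0 + clip_eps))
    return ratios, clipped

# If the policy has not drifted, current == rollout and the ratio is 1.0:
ratios, clipped = importance_ratios([-0.5, -2.0], [-0.5, -1.0])
print(ratios[0])   # 1.0
print(clipped[1])  # exp(-1.0) is about 0.368, which clips to 0.8
```

This is why missing or misaligned `inference_logprobs` break training: without the rollout term, the ratio is meaningless.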
When distillation data is attached to Atropos batches, the trainer treats
distill_token_ids as indices into the student's logit tensor. That only works
if the teacher and student share the same tokenizer vocabulary.
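The compatibility requirement boils down to an identical token-to-id mapping. A minimal sketch (the helper and toy vocabularies are illustrative, not part of the trainer):

```python
# Distillation targets are only valid if teacher and student tokenizers
# map the same IDs to the same text.
def vocab_compatible(teacher_vocab, student_vocab):
    """True only if both tokenizers share an identical token -> id mapping."""
    return teacher_vocab == student_vocab

student = {"<s>": 0, "hello": 1, "world": 2}
same_teacher = dict(student)                         # same tokenizer: OK
other_teacher = {"<s>": 0, "world": 1, "hello": 2}   # same tokens, shuffled IDs

print(vocab_compatible(same_teacher, student))   # True
print(vocab_compatible(other_teacher, student))  # False: ID 1 means different text
```

In the second case, `distill_token_ids` would index the wrong rows of the student's logit tensor, silently corrupting the supervision.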
What to configure on the environment side:
--env.teacher_enabled true \
--teacher.base_url "http://localhost:9003/v1" \
--teacher.model_name "$TEACHER_MODEL" \
--teacher.server_type vllm \
  --env.teacher_top_k 8

If $TEACHER_MODEL is a deployment alias instead of a tokenizer identifier,
also set --teacher.tokenizer_name ... so the env can validate
tokenizer compatibility.
The teacher-aware CLI path is currently wired for serve. If
teacher_enabled=True, the generic process and evaluate commands are not
teacher-aware and will fail loudly unless the environment is instantiated
manually with teacher_server_configs=....
Why cross-tokenizer conversion is not acceptable here:
- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
--clip-eps 0.2  # Limits importance sampling ratio to [0.8, 1.2]

Symptoms of missing/misconfigured clipping:
- Accuracy drops dramatically (e.g., 59% → 7%)
- Loss goes to very negative values (< -10)
- Model outputs become repetitive/degenerate
- `mean_ratio` diverges far from 1.0
For background on clipping and importance sampling, see https://fengyao.notion.site/off-policy-rl
Use a short linear warmup when training from fresh runs or small batch settings:
--warmup-steps 20

This linearly ramps the learning rate from 0 to --lr over the first N optimizer steps.
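The ramp is simple to state in code. A minimal sketch (not the trainer's actual scheduler; the helper name is hypothetical):

```python
# Linear LR warmup: ramp from 0 to base_lr over the first warmup_steps steps.
def warmup_lr(base_lr, step, warmup_steps):
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    return base_lr * (step / warmup_steps)

# With --lr 1e-5 --warmup-steps 20:
print(warmup_lr(1e-5, 0, 20))   # 0.0
print(warmup_lr(1e-5, 10, 20))  # 5e-06 (halfway)
print(warmup_lr(1e-5, 20, 20))  # 1e-05 (fully warmed up)
```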
Healthy training metrics:
- `mean_ratio`: 0.8 - 1.2 (close to 1.0)
- `clipped_fraction`: < 0.3 (< 30% of tokens clipped)
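These ranges are easy to encode in a quick log-scraping check. The helper below is hypothetical (not part of the trainer), just restating the thresholds above:

```python
# Healthy-range check for the two key GRPO training metrics.
def metrics_healthy(mean_ratio, clipped_fraction):
    """True if metrics fall in the healthy ranges documented above."""
    return 0.8 <= mean_ratio <= 1.2 and clipped_fraction < 0.3

print(metrics_healthy(1.02, 0.12))  # True
print(metrics_healthy(1.8, 0.12))   # False: ratio diverged from 1.0
print(metrics_healthy(1.0, 0.45))   # False: too many tokens clipped
```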
| Model Size | GPU Memory | Recommended Settings |
|---|---|---|
| 8B | 80GB | --gpu-memory-utilization 0.5 |
| 14B | 80GB | --gpu-memory-utilization 0.45, --batch-size 2 |
| 24B | 192GB (B200) | --gpu-memory-utilization 0.30, --optimizer adafactor |
🔧 B200/Blackwell GPU Support:
The trainer includes automatic patches for NVIDIA B200 (Blackwell architecture) GPUs when using LoRA mode. These patches disable Grid Dependency Control (GDC) in vLLM's Triton kernels, which causes compilation failures on Blackwell GPUs. The patches are applied automatically when:
- `VLLM_ENABLE_SHARED_WEIGHTS=1` is set, or
- `NUM_INFERENCE_NODES` is set (distributed inference path)
The patching clears the Triton cache and disables GDC to ensure compatibility. No manual intervention required.
The trainer supports multiple optimizer options to trade off between speed, memory, and precision:
| Optimizer | GPU Memory for States | Speed | Precision | Dependencies |
|---|---|---|---|---|
| `adamw` | Highest | Fastest | Full FP32 | None |
| `adamw_8bit` (default) | Lower | Fast | 8-bit quantized | bitsandbytes |
| `adafactor` | Lower | Fast | Full (no momentum) | transformers |
Usage:
# 8-bit AdamW (default) - recommended for memory-constrained setups
--optimizer adamw_8bit
# Standard AdamW - full precision
--optimizer adamw
# Adafactor - no momentum states, good for large models
--optimizer adafactor

Recommendations:
- 8B models on 80GB: Use `adamw` (fastest)
- 14B+ models on 80GB: Use `adamw_8bit` or `adafactor`
- 24B models: Use `adafactor` with reduced batch size
Potential Risks:
- `adamw_8bit`: Quantization may slightly affect convergence in edge cases; generally safe
- `adafactor`: No momentum can make training slightly less stable; use with larger batch sizes
vLLM fuses certain layers for efficiency, but HuggingFace keeps them separate:
HuggingFace Model: vLLM Model:
├── q_proj [4096, 4096] ├── qkv_proj [6144, 4096] ← FUSED!
├── k_proj [1024, 4096] │ (contains q, k, v concatenated)
├── v_proj [1024, 4096] │
│ │
├── gate_proj [14336, 4096] ├── gate_up_proj [28672, 4096] ← FUSED!
├── up_proj [14336, 4096] │ (contains gate and up concatenated)
The trainer creates views into vLLM's fused tensors:
# vLLM has: qkv_proj.weight [6144, 4096]
# We need: q_proj [4096], k_proj [1024], v_proj [1024]
# Get sizes from model config
q_size = num_heads * head_dim # e.g., 4096
k_size = num_kv_heads * head_dim # e.g., 1024
v_size = num_kv_heads * head_dim # e.g., 1024
# Create views (no copy!)
hf_model.q_proj.weight = vllm_qkv[0:4096, :] # First chunk
hf_model.k_proj.weight = vllm_qkv[4096:5120, :] # Second chunk
hf_model.v_proj.weight = vllm_qkv[5120:6144, :]  # Third chunk

# These point to the SAME GPU memory:
trainer_q_proj.data_ptr() == vllm_qkv_proj.data_ptr() # True!
# So when optimizer updates trainer weights:
optimizer.step() # Updates trainer_q_proj
# vLLM sees the change immediately (same memory)!

vLLM exports tensor mappings to vllm_bridge_config.json:
{
"model": "NousResearch/Hermes-3-Llama-3.1-8B",
"param_mappings": {
"model.layers.0.self_attn.qkv_proj.weight": {
"ipc_handle": "base64_encoded_cuda_ipc_handle",
"shape": [6144, 4096],
"dtype": "bfloat16"
}
}
}

A: Look for these log messages during training:
[WARNING] ref_logprobs at generated positions avg 0.85 (should be negative!)
[WARNING] This suggests inference_logprobs alignment is wrong
This means inference logprobs aren't being passed correctly. Debug steps:
1. Check environment server type:

   # Must be 'vllm', NOT 'openai'
   --openai.server_type vllm

2. Verify vLLM returns logprobs:

   curl -X POST http://localhost:9001/generate \
     -H "Content-Type: application/json" \
     -d '{"prompt": "Hello", "max_tokens": 5}'
   # Response should include "logprobs": [...]

3. Check data.py logs:

   [Data] ✓ inference_logprobs found in batch (sample len: 128)

4. Monitor alignment metrics in training logs:
   - `alignment/diff_mean` should be close to 0 at step start
   - `alignment/diff_abs_mean` < 0.1 = good alignment
   - Large values = weights not properly shared or logprobs misaligned
# Start the API server first
run-api --port 8002

You're using a vLLM server that doesn't expose /generate. Use our custom server:
python -m example_trainer.vllm_api_server ... # Has /generate
# NOT: python -m vllm.entrypoints.openai.api_server  # Only has /v1/*

vLLM v1 engine issue. We disable it by default, but if you see this:
VLLM_USE_V1=0 python -m example_trainer.vllm_api_server ...

This warning appears during training when inference logprobs alignment is incorrect. Weight updates may not be visible to inference. Fix:
# Add --enforce-eager to vLLM
python vllm_api_server.py --model $MODEL --enforce-eager

You may also see related alignment warnings:
[WARNING] This suggests inference_logprobs alignment is wrong
[DEBUG] Logprob gap: ref=X.XXX, train=X.XXX
Reduce memory usage:
--gpu-memory-utilization 0.4 # Less vLLM memory
--batch-size 2 # Smaller batches
--gradient-accumulation-steps 8 # Compensate with accumulation
--seq-len 1024 # Shorter sequences
--optimizer adafactor           # Uses less memory than AdamW

vLLM version incompatibility. Our server handles this automatically, but make sure you're using:
python -m example_trainer.vllm_api_server  # NOT direct vllm commands

--use-wandb \
--wandb-project "my-grpo-training" \
  --wandb-group "hermes-8b-gsm8k"

| Argument | Default | Description |
|---|---|---|
| `--model-name` or `--model` | (required) | HuggingFace model ID |
| `--weight-bridge-mode` | `none` | `shared_vllm`, `lora_only`, `lora_restart`, or `none` |
| `--training-steps` | 10 | Number of training steps |
| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
| `--batch-size` | 2 | Micro-batch size |
| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
| `--warmup-steps` | 0 | Linear LR warmup steps (0 disables warmup) |
| `--seq-len` | 2048 | Maximum sequence length |
| `--train-layer-indices` | None | Optional full-model layer filter for shared/legacy modes (examples: `20-31`, `0-3,28-31`) |
| Argument | Default | Description |
|---|---|---|
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
| `--lr` | 1e-5 | Learning rate (NOT `--learning-rate`) |
| Argument | Default | Description |
|---|---|---|
| `--lora-r` | 16 | LoRA rank (dimension of low-rank matrices) |
| `--lora-alpha` | 32 | LoRA alpha scaling factor |
| `--lora-dropout` | 0.05 | LoRA dropout probability |
| `--lora-target-modules` | None | Module names to apply LoRA (None falls back to `q_proj`, `v_proj`) |
| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
Layer-index arguments are model-dependent (--train-layer-indices for full/shared modes, --lora-layer-indices for LoRA modes). Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another.
| Architecture family | Common config fields | Typical layer list path | Notes |
|---|---|---|---|
| LLaMA / Llama-2 / Llama-3 / Mistral | `num_hidden_layers` | `model.layers` | Most common causal-LM layout |
| Qwen / Qwen2 / Qwen2.5 / Qwen3 | `num_hidden_layers` | `model.layers` | Similar layer indexing to LLaMA |
| GPT-2 / GPT-J style | `n_layer` (or mapped to `num_hidden_layers`) | `transformer.h` | PEFT may use `h` pattern internally |
| Falcon | `num_hidden_layers` | `transformer.h` | Uses `h` block list in model module tree |
Always query the model config before choosing indices:
python - <<'PY'
from transformers import AutoConfig
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
cfg = AutoConfig.from_pretrained(model_id)
num_layers = getattr(cfg, "num_hidden_layers", None)
if num_layers is None:
num_layers = getattr(cfg, "n_layer", None)
print(f"model={model_id}")
print(f"num_hidden_layers={num_layers}")
if num_layers is not None:
print(f"valid index range: 0-{num_layers-1}")
PY

If your model has N layers:
- Full layers: omit `--train-layer-indices`
- Top 25%: `--train-layer-indices {int(0.75*N)}-{N-1}`
- Top 50%: `--train-layer-indices {int(0.5*N)}-{N-1}`
- Last 12 layers: `--train-layer-indices {N-12}-{N-1}` (if `N >= 12`)
| Argument | Default | Description |
|---|---|---|
| `--vllm-port` | 9001 | vLLM server port |
| `--vllm-config-path` | auto | Path to bridge config (shared mode) |
| `--gpu-memory-utilization` | 0.45 | vLLM GPU memory fraction |
| `--vllm-gpu` | None | GPU ID for vLLM (None = same as trainer) |
| `--max-model-len` | 4096 | Maximum context length |
| `--dtype` | `bfloat16` | Model dtype: `bfloat16`, `float16`, or `auto` |
| `--vllm-restart-interval` | 3 | Restart vLLM every N steps (legacy/lora_restart) |
| Argument | Default | Description |
|---|---|---|
| `--atropos-url` | `http://localhost:8000` | URL of the Atropos API server |
Note: Many examples in this README use http://localhost:8002 because they start run-api --port 8002.
| Argument | Default | Description |
|---|---|---|
| `--use-wandb` | False | Enable W&B logging |
| `--wandb-project` | None | W&B project name |
| `--wandb-group` | None | W&B group name (auto-generated if omitted) |
| Argument | Default | Description |
|---|---|---|
| `--trainer-rank` | 0 | Trainer rank |
| `--world-size` | 1 | World size |
| `--init-method` | `env://` | Distributed init method |
| `--num-inference-nodes` | 0 | Number of inference nodes |
| Argument | Default | Description |
|---|---|---|
| `--debug-loading` | False | Verbose model loading diagnostics |
| `--benchmark` | False | Print benchmark/timing metrics |
| `--log-dir` | `./logs` | Directory for unified launcher logs |
| Module | Purpose |
|---|---|
| `grpo.py` | CLI entry point, dispatches to training modes (4 modes) |
| `run.py` | Unified launcher for shared_vllm mode (starts vLLM + trainer) |
| `cli.py` | Single source of truth for all CLI arguments (modular builders) |
| `config.py` | TrainingConfig Pydantic model with all hyperparameters |
| `api.py` | Communication with Atropos API (registration, batch fetching) |
| `data.py` | Batch preprocessing, padding, logprob extraction and alignment |
| `model.py` | Model loading, CUDA IPC attachment, tensor mapping (QKV/Gate fusion) |
| `training.py` | GRPO loss computation (importance sampling and clipping) |
| `trainers.py` | Mode-specific training loops (4 implementations + optimizer selection) |
| `vllm_api_server.py` | Custom vLLM server with /generate endpoint and LoRA support |
| `vllm_manager.py` | vLLM process lifecycle management (launch, health checks, termination) |
| `checkpointing.py` | Save/load checkpoints and adapters (handles fused tensor unfusing) |
1. CLI Parsing (cli.py)
↓
2. Config Creation (config.py)
↓
3. Mode Dispatcher (grpo.py or run.py)
↓
4. Trainer Function (trainers.py)
├─ Setup Phase
│ ├─ Initialize W&B (training.py)
│ ├─ Load Model (model.py)
│ ├─ Create Optimizer (trainers.py)
│ ├─ Check Atropos API (api.py)
│ ├─ Register Trainer (api.py)
│ └─ Launch/Connect vLLM (vllm_manager.py or external)
│
└─ Training Loop
├─ Fetch Batch (api.py → data.py)
│ ├─ Poll /batch endpoint
│ ├─ Pad sequences (data.py)
│ ├─ Extract inference logprobs (data.py)
│ └─ Normalize advantages (data.py)
│
├─ Training Step (training.py)
│ ├─ For each micro-batch:
│ │ ├─ Forward pass (model)
│ │ ├─ Compute GRPO loss (training.py)
│ │ │ ├─ Temperature scaling
│ │ │ ├─ Compute log probabilities
│ │ │ ├─ Importance sampling ratio (using inference logprobs)
│ │ │ ├─ PPO clipping
│ │ │ └─ Return loss + metrics
│ │ └─ Backward pass (accumulate gradients)
│ ├─ Clip gradients (norm=1.0)
│ ├─ Optimizer step
│ └─ Zero gradients
│
├─ Weight Sync (mode-dependent)
│ ├─ shared_vllm: No sync needed (weights shared via CUDA IPC)
│ ├─ lora_only: HTTP POST to /lora/load
│ ├─ lora_restart: Save adapter + terminate + relaunch vLLM
│ └─ none: Save checkpoint + terminate + relaunch vLLM
│
├─ Log Metrics (training.py)
│ ├─ Console output
│ └─ W&B logging (if enabled)
│
└─ Periodic Checkpoint (checkpointing.py)
├─ Ensure tensors are contiguous (unfuse views)
├─ Save state dict
└─ Free GPU memory
# Entry: grpo.py → trainers.train_shared_vllm()
1. Model Loading (model.py):
- Find vllm_bridge_config.json
- Load IPC handles (CUDA memory pointers)
- Create empty model on meta device
- Reconstruct tensors from IPC handles
- Map vLLM fused tensors → HF unfused parameters
* qkv_proj → q_proj, k_proj, v_proj (views)
* gate_up_proj → gate_proj, up_proj (views)
- Initialize remaining meta tensors (buffers, etc.)
2. Training Loop:
- optimizer.step() directly modifies vLLM's tensors
- No weight synchronization needed!
- Checkpoints: Unfuse views before saving (checkpointing.py)
3. Tensor Mapping (model.py:_create_vllm_to_hf_mapping):
- Reads actual HF tensor shapes from model.state_dict()
- Creates slice mappings for fused layers
- Example: q_proj = qkv_proj[0:4096, :]

# Entry: grpo.py → trainers.train_lora_restart()
1. Model Loading (model.py):
- Load base model with PEFT
- Apply LoRA config to target modules
- Freeze base weights, only LoRA trainable
2. vLLM Management:
- Launch: _launch_vllm_with_lora()
* NO --enforce-eager flag (CUDA graphs enabled)
* Pre-load initial adapter
- Periodic Restart:
* Save new adapter (checkpointing.py)
* Terminate vLLM aggressively (_terminate_vllm)
- Kill process group
- Kill by port (fuser)
- Kill by process name patterns
- Wait for GPU memory release (critical!)
* Relaunch with new adapter
3. Performance:
- ~108 TPS (CUDA graphs enabled)
- ~45s restart overhead
- Much faster than lora_only (~8x speedup)

# Entry: grpo.py → trainers.train_lora()
1. Model Loading: Same as lora_restart
2. vLLM: External server (must be pre-started)
- MUST use --enforce-eager (disables CUDA graphs)
- MUST use --enable-lora
3. Weight Sync: _hotswap_lora_adapter()
- Tries /v1/load_lora_adapter (native vLLM)
- Falls back to /lora/load (custom endpoint)
4. Performance:
- ~13 TPS (CUDA graphs disabled)
- No restart overhead
- 8x slower than lora_restart!

# Entry: grpo.py → trainers.train_legacy()
1. Model Loading: Full model (model.py)
2. vLLM Management:
- Launch: vllm_manager.launch_vllm_server()
- Periodic Restart:
* Save full checkpoint (checkpointing.py)
* Terminate vLLM (vllm_manager.terminate_vllm_process)
* Relaunch with new checkpoint
3. Use Case:
- Different GPUs for trainer and vLLM
- Simple setup without CUDA IPC or LoRA

# api.get_batch() → data.get_data() → data.pad_data_to_good_offset()
1. Batch Structure from API:
{
"batch": [
{
"tokens": [[tok1, tok2, ...], ...], # group_size sequences
"masks": [[mask1, mask2, ...], ...], # -100 for prompt, token_id for generated
"scores": [score1, score2, ...], # rewards
"inference_logprobs": [[lp1, lp2, ...], ...], # required for this GRPO trainer
"generation_params": {"temperature": 1.0},
...
}
]
}
2. Preprocessing (pad_data_to_good_offset):
- Normalize advantages (mean=0, std=1 per group)
- Pad sequences to multiple of 64
- Align inference_logprobs with labels:
* 1.0 for prompt tokens (masked)
* Actual negative logprobs for generated tokens
* Shift by 1 for causal alignment
- Extract temperatures (priority: override > generation_params > 1.0)
- Batch into micro-batches
3. Output:
- token_batches: [B, seq_len]
- label_batches: [B, seq_len] # -100 for masked
- advantage_batches: [B, 1]
- temperature_batches: [B, 1, 1]
- inference_logprob_batches: [B, seq_len]  # aligned with labels!

# training.compute_grpo_loss()
1. Forward Pass:
- Get logits from model
- Apply temperature scaling (from data)
- Compute log probabilities per token
2. Rollout Logprobs:
- Extract from inference_logprobs (from vLLM at generation time)
- Already aligned with labels by data.py
3. Importance Sampling:
- log_ratio = current_logprob - rollout_inference_logprob
- ratio = exp(log_ratio)
- Clipped ratio = clip(ratio, 1-ε, 1+ε)
4. Policy Loss:
- surr1 = ratio * advantage
- surr2 = clipped_ratio * advantage
- policy_loss = -min(surr1, surr2) # pessimistic bound
5. Total Loss:
- loss = policy_loss
- Scaled by 1/gradient_accumulation_steps
6. Metrics:
- mean_ratio: Average importance sampling ratio
- clipped_fraction: % of tokens clipped
- alignment/* : Token-level logprob alignment (verifies weight sharing)

For algorithm background and design tradeoffs, see the off-policy RL reference linked above.
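Steps 3-4 of the loss computation can be written out in plain Python. This is illustrative only; the real loss in training.py is vectorized over token tensors, and the function name here is hypothetical:

```python
# Per-token GRPO/PPO clipped surrogate: -min(ratio*A, clip(ratio)*A).
import math

def grpo_policy_loss(current_lp, rollout_lp, advantage, clip_eps=0.2):
    ratio = math.exp(current_lp - rollout_lp)          # importance sampling
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surr1 = ratio * advantage
    surr2 = clipped * advantage
    return -min(surr1, surr2)                          # pessimistic bound

# On-policy token (ratio = 1) with positive advantage: loss = -advantage.
print(grpo_policy_loss(-1.0, -1.0, advantage=0.5))  # -0.5
# Drifted token (ratio > 1.2) with positive advantage: update is clipped.
print(grpo_policy_loss(-0.5, -1.0, advantage=0.5))  # -0.6 (ratio clipped to 1.2)
```

The pessimistic min is what prevents the degenerate behavior listed under "Symptoms of missing/misconfigured clipping": a large ratio cannot produce an arbitrarily large update.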