143 changes: 143 additions & 0 deletions apps/grpo/qwen3_30b_a3.yaml
@@ -0,0 +1,143 @@
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_30b_a3.yaml
# NOTE - This has not been tested for correctness yet! All testing so far has been for infrastructure stability only.
# Global configuration
group_size: 16
local_batch_size: 2 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default
provisioner:
  launcher: slurm
# Main loop configuration
rollout_threads: 4 # 4x the number of policy replicas seems to work well
# Observability configuration
metric_logging:
  wandb:
    project: "grpo-training"
    group: "grpo_exp_${oc.env:USER}"
    reduce_across_ranks: True
  console:
    reduce_across_ranks: True
# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}
# Policy configuration
policy:
  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    tensor_parallel_size: 4
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: 2048
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: -1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: full
# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
  dp_size: 16 # trainer DP degree: (2 hosts x 8 procs) / (TP * PP * CP) = 16; see the sketch after this file
# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  training:
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 4
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
# All resource allocations
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 1
    hosts: 1
    with_gpus: true
    mesh_name: policy
  ref_model:
    procs: ${ref_model.parallelism.tensor_parallel_degree}
    num_replicas: 1
    with_gpus: true
    mesh_name: ref_model
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
    mesh_name: reward_actor
actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    procs: 8
    hosts: 2
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
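
The commented-out interpolation in replay_buffer above encodes a real constraint: dp_size must equal the trainer's data-parallel degree. Below is a minimal sanity-check sketch of that rule, not part of the PR, assuming these configs load with OmegaConf (the ${oc.env:USER} syntax is OmegaConf's built-in env resolver), that actors.trainer.procs counts processes per host, and that data_parallel_shard_degree: -1 means "shard over all remaining trainer ranks".

# check_dp_size.py -- hypothetical sanity check, not part of the PR
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_30b_a3.yaml")

# Total trainer ranks: 2 hosts x 8 procs per host = 16 GPUs (assumed semantics)
trainer_world = cfg.actors.trainer.hosts * cfg.actors.trainer.procs

# Dimensions that do not contribute to data parallelism; with
# data_parallel_shard_degree = -1, DP takes everything that remains.
par = cfg.trainer.parallelism
non_dp = (par.tensor_parallel_degree
          * par.pipeline_parallel_degree
          * par.context_parallel_degree)
dp_degree = trainer_world // non_dp  # 16 // 1 = 16

assert cfg.replay_buffer.dp_size == dp_degree, (
    f"replay_buffer.dp_size={cfg.replay_buffer.dp_size} "
    f"!= trainer DP degree {dp_degree}")
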
14 changes: 7 additions & 7 deletions apps/grpo/qwen3_32b.yaml
@@ -4,7 +4,7 @@

# Global configuration
group_size: 16
-local_batch_size: 32 # per-device batch size
+local_batch_size: 2 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
@@ -14,7 +14,7 @@ provisioner:
  launcher: slurm

# Main loop configuration
-rollout_threads: 32 # 4x the number of policy replicas seems to work well
+rollout_threads: 4 # 4x the number of policy replicas seems to work well

# Observability configuration
metric_logging:
@@ -69,8 +69,8 @@ trainer:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: 1
-    tensor_parallel_degree: 8
+    data_parallel_shard_degree: -1
+    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
@@ -90,7 +90,7 @@ replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
-  dp_size: 1
+  dp_size: 32 # trainer DP degree with 4 hosts x 8 procs and TP=1

# Reference model configuration
ref_model:
@@ -119,7 +119,7 @@ ref_model:
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
-    num_replicas: 4
+    num_replicas: 1
    hosts: 1
    with_gpus: true
    mesh_name: policy
@@ -141,7 +141,7 @@ actors:
    mesh_name: dataset
  trainer:
    procs: 8
-    hosts: 1
+    hosts: 4
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
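
For reference, here is how the ${...} references in these configs resolve, again assuming OmegaConf semantics; this is illustrative only, with values taken from apps/grpo/qwen3_30b_a3.yaml as added in this PR.

# resolve_config.py -- illustrative only, not part of the PR
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_30b_a3.yaml")

print(cfg.policy.engine_args.model)    # "Qwen/Qwen3-32B", via ${model}
print(cfg.policy.sampling_params.n)    # 16, via ${group_size}
print(cfg.services.policy.procs)       # 4, via ${policy.engine_args.tensor_parallel_size}
print(cfg.metric_logging.wandb.group)  # "grpo_exp_<USER>", via the oc.env resolver
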