Commit ba8b4b4

[Algorithm] Async GRPO
1 parent 16b70be commit ba8b4b4

31 files changed, +2037 -355 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ htmlcov/
 .coverage
 .coverage.*
 .cache
+.neptune
 nosetests.xml
 coverage.xml
 *.cover

docs/source/reference/llms.rst

Lines changed: 23 additions & 1 deletion
@@ -10,9 +10,29 @@ TorchRL offers a set of tools for LLM post-training, as well as some examples fo
 Collectors
 ----------
 
-TorchRL offers a specialized collector class (:class:`~torchrl.collectors.llm.LLMCollector`) that is tailored for LLM
+TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) that are tailored for LLM
 use cases. We also provide dedicated updaters for some inference engines.
 
+LLM collectors allow tracking the version of the policy, which is useful for some use cases.
+This is done by adding a :class:`~torchrl.envs.llm.transforms.PolicyVersion` transform to the environment, which is
+then incremented by the collector after each weight update. To do this, one either provides the stateful version of the
+transform, or a boolean to the collector constructor.
+
+>>> from torchrl.envs.llm.transforms import PolicyVersion
+>>> from torchrl.collectors.llm import LLMCollector
+>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
+>>> env = make_env()  # place your code here
+>>> policy = make_policy()  # place your code here
+>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
+>>> # init the updater
+>>> collector.weight_updater.init(...)
+>>> # the version is incremented after each weight update
+>>> collector.update_policy_weights_(state_dict=...)
+>>> print(collector.policy_version_tracker.version)
+>>> # the policy version is written in the data
+>>> for data in collector:
+...     print(data["policy_version"])
+
 .. currentmodule:: torchrl.collectors.llm
 
 .. autosummary::
@@ -21,6 +41,7 @@ use cases. We also provide dedicated updaters for some inference engines.
 
     vLLMUpdater
     LLMCollector
+    RayLLMCollector
 
 
 Data structures
@@ -182,6 +203,7 @@ transforms).
     MCPToolTransform
     BrowserTransform
     PythonInterpreter
+    PolicyVersion
     TemplateTransform
     Tokenizer
     as_nested_tensor
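A note on the mechanism documented above: version tracking boils down to a counter held by a transform, bumped by the collector after every weight update, and stamped into each collected batch. A minimal, framework-free sketch of that idea follows; the `ToyPolicyVersion` and `ToyCollector` names are illustrative stand-ins, not TorchRL's implementation.

```python
# Toy sketch of policy-version tracking; TorchRL's PolicyVersion/LLMCollector differ in detail.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyPolicyVersion:
    """Stateful transform-like object holding the current policy version."""
    version: int = 0

    def increment(self) -> None:
        self.version += 1

    def stamp(self, batch: dict) -> dict:
        # Write the current version into the collected data.
        batch["policy_version"] = self.version
        return batch


@dataclass
class ToyCollector:
    policy_version_tracker: ToyPolicyVersion = field(default_factory=ToyPolicyVersion)

    def update_policy_weights_(self, state_dict: Any) -> None:
        # ... push the new weights to the inference engine here ...
        self.policy_version_tracker.increment()

    def collect(self, raw_batch: dict) -> dict:
        return self.policy_version_tracker.stamp(raw_batch)


collector = ToyCollector()
collector.update_policy_weights_(state_dict={})
print(collector.collect({"obs": 1})["policy_version"])  # -> 1
```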

sota-implementations/grpo/README.md

Lines changed: 47 additions & 8 deletions
@@ -11,6 +11,7 @@ GRPO is a method for training language models using reinforcement learning, with
 - Automatic checkpointing
 - Comprehensive logging with Weights & Biases
 - Hydra configuration system
+- Asynchronous training support with Ray
 
 ## Installation
 
@@ -34,7 +35,27 @@ export VLLM_USE_V1=0 # Required for vLLM compatibility
 - vLLM inference device
 - Reference model device
 
-Devices can be controlled via the `training_model.devices`, `inference_model.devices` and `ref_model.devices` arguments.
+### Device Management
+
+There are two ways to specify device allocation:
+
+1. Using `num_devices` (Recommended):
+```bash
+train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
+```
+This approach automatically manages device allocation based on the training mode (sync/async) and prevents device conflicts.
+
+2. Using `devices` (Manual):
+```bash
+train_model.devices=[0,1] ref_model.devices=[2,3] inference_model.devices=[4,5]
+```
+This approach requires manual device management and is more error-prone.
+
+The `num_devices` approach is recommended as it:
+- Automatically handles device allocation
+- Works correctly in both sync and async modes
+- Prevents device conflicts between model components
+- Is more portable across different machine configurations
 
 ## Configuration
 
@@ -46,10 +67,24 @@ The training configuration is managed through Hydra. There are two main configur
 
 ### Basic Training
 
+There are two training modes available:
+
+#### Synchronous Mode (Default)
 ```bash
-python grpo.py
+VLLM_USE_V1=0 python sota-implementations/grpo/grpo.py train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
 ```
 
+#### Asynchronous Mode (Recommended)
+```bash
+VLLM_USE_V1=0 python sota-implementations/grpo/grpo-async.py train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
+```
+
+The async mode offers better performance through:
+- Concurrent data collection and optimization
+- More efficient GPU utilization
+- Reduced memory overhead
+- Better throughput
+
 ### Run with IFEval Config
 
 ```bash
@@ -63,7 +98,7 @@ python grpo.py --config-name grpo_ifeval
 python grpo.py env.dataset=ifeval
 
 # Modify training parameters
-python grpo.py train.epochs=2 train.optimizer.lr=2e-5
+python grpo.py optimizer.lr=2e-5 optimizer.weight_decay=0.01
 
 # Change model
 python grpo.py model.name=meta-llama/Llama-2-7b-hf
@@ -73,14 +108,16 @@ python grpo.py model.name=meta-llama/Llama-2-7b-hf
 
 ```bash
 # Learning rate sweep
-python grpo.py --multirun train.optimizer.lr=1e-4,1e-5,1e-6
+python grpo.py --multirun optimizer.lr=1e-4,1e-5,1e-6
 
 # Multiple parameters
 python grpo.py --multirun \
-  train.optimizer.lr=1e-4,1e-5 \
+  optimizer.lr=1e-4,1e-5 \
   policy.kl_coef=0.01,0.1
 ```
 
+Don't forget to set `train.total_dialog_turns` to a reasonable value!
+
 ## Monitoring
 
 Training progress is logged to Weights & Biases with the following metrics:
@@ -91,10 +128,11 @@ Training progress is logged to Weights & Biases with the following metrics:
 - ESS (Effective Sample Size)
 - Loss metrics (objective, clip fraction, etc.)
 - Gradient norm
+- Throughput metrics (in async mode)
 
 ## Checkpointing
 
-Checkpoints are saved every `logging.checkpoint_frequency` batches and contain:
+Checkpoints are saved every `train.checkpoint_frequency` steps and contain:
 - Model state
 - Optimizer state
 - Gradient scaler state (for mixed precision)
@@ -114,8 +152,9 @@ Checkpoints are saved every `logging.checkpoint_frequency` batches and contain:
 sota-implementations/grpo/
 ├── config/
 │   └── grpo_gsm8k.yaml   # Main configuration file
-│   └── grpo_ifeval.yaml  # config file for IFEval task
-├── grpo.py               # Training script
+│   └── grpo_ifeval.yaml  # config file for IFEval task
+├── grpo.py               # Synchronous training script
+├── grpo-async.py         # Asynchronous training script
 ├── grpo_utils.py         # Utility functions
 └── README.md             # This file
 ```
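The async mode added by this commit amounts to a producer/consumer loop: an inference worker keeps appending rollouts to a replay buffer while the trainer samples from it and periodically pushes refreshed weights back to the inference engine. The sketch below illustrates that pattern with plain Python threads and a deque; it is a conceptual stand-in, not the actual `grpo-async.py`, which uses Ray actors and TorchRL's collector and replay buffer.

```python
# Conceptual sketch of async GRPO training: data collection and optimization run concurrently.
# This is not the actual grpo-async.py; names and the update rule are placeholders.
import random
import threading
import time
from collections import deque

buffer = deque(maxlen=128)      # stand-in for the replay buffer (train.buffer_size)
buffer_lock = threading.Lock()
stop = threading.Event()
weight_update_frequency = 50    # mirrors train.weight_update_frequency in the config


def collect_rollouts() -> None:
    """Producer: generate rollouts with the (possibly slightly stale) inference policy."""
    while not stop.is_set():
        rollout = {"reward": random.random()}   # placeholder for one dialog turn
        with buffer_lock:
            buffer.append(rollout)
        time.sleep(0.001)                       # simulate generation latency


def train(num_steps: int = 200) -> None:
    """Consumer: optimize on sampled data and periodically refresh the inference weights."""
    for step in range(1, num_steps + 1):
        time.sleep(0.001)                       # give the producer time to fill the buffer
        with buffer_lock:
            if not buffer:
                continue
            batch = random.sample(list(buffer), k=min(4, len(buffer)))
        _ = sum(item["reward"] for item in batch)   # placeholder for the GRPO update
        if step % weight_update_frequency == 0:
            pass  # here the trainer would push new weights to the inference engine
    stop.set()


producer = threading.Thread(target=collect_rollouts, daemon=True)
producer.start()
train()
producer.join(timeout=1)
print(f"collected {len(buffer)} rollouts while training")
```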

sota-implementations/grpo/config/grpo_gsm8k.yaml

Lines changed: 64 additions & 36 deletions
@@ -1,4 +1,6 @@
+# @package _global_
 defaults:
+  - mode: async # Default to async mode, will be overridden by grpo.py
   - _self_
   - override hydra/hydra_logging: disabled
   - override hydra/job_logging: disabled
@@ -17,10 +19,30 @@ model:
   name: Qwen/Qwen2.5-3B
   compile: false
 
+# Base training configuration - will be merged with mode-specific settings
+train:
+  # Fields defined in mode configs (async.yaml and sync.yaml)
+  # mixed_precision: true # Whether to use mixed precision training
+  # epochs: 1 # Number of training epochs
+  # steps_per_batch: 32 # Number of steps per batch
+  # total_dialog_turns: 1_000_000 # Total number of dialog turns to collect
+  # optim_batch_size: 2 # Batch size for optimization
+  # gradient_accumulation_steps: 1 # Number of gradient accumulation steps
+  # kl_coef_in_loss: true # Whether to include KL coefficient in loss
+  # sync: false # Default to async, will be overridden by mode configs
+  # buffer_size: 128 # Size of replay buffer
+
+  # Fields used by both scripts but with different semantics
+  checkpoint_frequency: 100 # Save checkpoint every N steps/batches
+
+  # Fields used only by grpo-async.py
+  weight_update_frequency: 50 # Update policy weights every N steps
+  logging_frequency: 10 # Log metrics every N steps
 # Training model configuration
 train_model:
   gradient_checkpointing: true # Enabled for memory efficiency
-  devices: [0] # List of GPU devices to use for training
+  num_devices: 1 # Number of devices to use
+  devices: null # Will be computed by compute_device_allocation
   lora:
     enabled: true # Using LoRA for memory efficiency
     r: 8 # LoRA rank - controls capacity of adaptations
@@ -31,57 +53,63 @@ train_model:
   attn_implementation: sdpa # Using flash attention for memory efficiency
   torch_dtype: bfloat16
 
-# Inference model configuration (vLLM)
+# Inference model configuration
 inference_model:
-  devices: [1] # List of GPU devices to use for inference
-  gpu_memory_utilization: 0.5
+  num_devices: 1 # Number of devices to use
+  devices: null # Will be computed by compute_device_allocation
+  quantization:
+    enabled: false # Enable 4-bit quantization for base model
+  attn_implementation: sdpa # Using flash attention for memory efficiency
+  torch_dtype: bfloat16
+  gpu_memory_utilization: 0.5 # Limit GPU memory usage
   temperature: 0.8
   max_tokens: 1024
   include_stop_str_in_output: true
 
 # Reference model configuration
 ref_model:
-  devices: [2] # List of GPU devices to use for reference model
+  gradient_checkpointing: false # Always false, no backprop
+  num_devices: 1 # Number of devices to use
+  devices: null # Will be computed by compute_device_allocation
+  lora:
+    enabled: true # Using LoRA for memory efficiency
+    r: 8 # LoRA rank - controls capacity of adaptations
+    alpha: 16 # LoRA alpha - scales the adaptations
+    dropout: 0.1 # Dropout probability for LoRA layers
   quantization:
-    enabled: false # Enable quantization for memory efficiency
-  gradient_checkpointing: false # Not needed for reference model
-  attn_implementation:
+    enabled: false # Enable 4-bit quantization for base model
+  attn_implementation: sdpa # Using flash attention for memory efficiency
   torch_dtype: bfloat16
 
 # Policy configuration
 policy:
   kl_coef: 1e-2
 
-# Training configuration
-train:
-  epochs: 1
-  # Number of dialog turns per batch. This is passed to the collector and buffer.
-  # More steps do not consume more GPU memory, but it does affect the inference speed in
-  # that in sync contexts the training node will need to wait for a batch to be completed
-  # before starting the next one.
-  steps_per_batch: 64
-  # Total number of dialog turns to collect during training
-  total_dialog_turns: 1_000_000
-  # Number of batches to run in parallel. This determines the batch size passed to the optimizer.
-  # More batches consume more GPU memory.
-  optim_batch_size: 1
-  # Number of gradient accumulation steps. This determines the number of steps to run before
-  # updating the parameters.
-  gradient_accumulation_steps: 4 # Increased for gradient accumulation
-  # Whether to include the KL coefficient in the loss or in the environment reward.
-  kl_coef_in_loss: true
-  # Whether to use mixed precision.
-  mixed_precision: true # Disable mixed precision since we're not using it
-  optimizer:
-    name: AdamW
-    lr: 1e-5
-    clip_grad_norm: 0.5
-
+# Optimizer configuration
+optimizer:
+  name: AdamW
+  lr: 1e-5
+  clip_grad_norm: 100.0
+  weight_decay: 0.0
+# Ray configuration
+ray:
+  init_config:
+    num_cpus: 96 # Total available CPUs
+    num_gpus: 8 # Explicitly set number of GPUs
+    runtime_env:
+      working_dir: "."
+  collector_config:
+    num_cpus: 48 # CPUs for inference and ref model (co-located)
+  train_handler_config:
+    num_cpus: 24 # Dedicated CPUs for training
+  replay_buffer_config:
+    num_cpus: 24 # CPUs for replay buffer
+    num_gpus: 0.0 # No GPU needed for replay buffer
 # Logging configuration
 logging:
-  checkpoint_dir: checkpoints
-  experiment_name: null # auto-generated if null
-  checkpoint_frequency: 10 # save every N batches
+  experiment_name: null # Will be auto-generated if not provided
+  checkpoint_dir: "checkpoints"
+  checkpoint_frequency: 10 # Save checkpoint every N batches
 
 hydra:
   run:
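The `devices: null # Will be computed by compute_device_allocation` entries above refer to deriving explicit GPU lists from the `num_devices` fields so that the three model components do not collide. Below is a hypothetical sketch of such a derivation; the real `compute_device_allocation` helper referenced in the comments (presumably in `grpo_utils.py`) may differ, for instance by also taking the sync/async mode into account.

```python
# Hypothetical sketch of turning num_devices fields into explicit GPU lists;
# the actual compute_device_allocation used by the GRPO scripts may differ.
def allocate_devices(num_devices: dict[str, int], total_gpus: int = 8) -> dict[str, list[int]]:
    """Hand out disjoint, contiguous GPU index ranges to each model component."""
    allocation: dict[str, list[int]] = {}
    next_gpu = 0
    for component, count in num_devices.items():
        allocation[component] = list(range(next_gpu, next_gpu + count))
        next_gpu += count
    if next_gpu > total_gpus:
        raise ValueError(f"Requested {next_gpu} GPUs but only {total_gpus} are available")
    return allocation


# Example matching the async command in the README (2 GPUs per component):
print(allocate_devices({"train_model": 2, "ref_model": 2, "inference_model": 2}))
# {'train_model': [0, 1], 'ref_model': [2, 3], 'inference_model': [4, 5]}
```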

0 commit comments