Commit 7619b79

Changes for Dr. GRPO (#35)
1 parent 59eb01b commit 7619b79
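
The substantive change in this commit is the `masked_aggregator` update in `oat/algorithms/ppo.py`: when `critic_type` is `"drgrpo"`, per-token terms are summed and divided by the constant `generate_max_length` instead of being averaged over each response's own length. The sketch below illustrates that difference on toy data. It is not Oat's implementation: `toy_masked_mean` and `toy_masked_sum` are hypothetical pure-Python stand-ins for `masked_mean` and `masked_sum`, whose real signatures are assumed here rather than taken from the source.

```python
import functools
from typing import List, Sequence


def toy_masked_mean(
    values: Sequence[Sequence[float]], mask: Sequence[Sequence[int]]
) -> List[float]:
    """Average the unmasked entries of each row (divides by the row's own length)."""
    return [
        sum(v * m for v, m in zip(row, row_mask)) / sum(row_mask)
        for row, row_mask in zip(values, mask)
    ]


def toy_masked_sum(
    values: Sequence[Sequence[float]],
    mask: Sequence[Sequence[int]],
    constant_normalizer: float = 1.0,
) -> List[float]:
    """Sum the unmasked entries of each row, then divide by a fixed constant."""
    return [
        sum(v * m for v, m in zip(row, row_mask)) / constant_normalizer
        for row, row_mask in zip(values, mask)
    ]


# Two responses padded to an assumed generate_max_length of 4.
# The first response has 2 real tokens; the 9.0 entries are padding.
generate_max_length = 4
token_losses = [[1.0, 1.0, 9.0, 9.0], [1.0, 1.0, 1.0, 1.0]]
response_masks = [[1, 1, 0, 0], [1, 1, 1, 1]]

# Mirrors the functools.partial pattern used in the diff.
drgrpo_aggregator = functools.partial(
    toy_masked_sum, constant_normalizer=generate_max_length
)

per_length = toy_masked_mean(token_losses, response_masks)
constant = drgrpo_aggregator(token_losses, response_masks)
print(per_length)  # [1.0, 1.0]
print(constant)    # [0.5, 1.0]
```

With the per-length mean, each token of the short response carries weight 1/2 while each token of the long response carries weight 1/4. With the constant normalizer, every token carries weight 1/4 regardless of response length.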

2 files changed, +20 -15 lines changed

README.md

Lines changed: 14 additions & 9 deletions
@@ -12,9 +12,10 @@
 ---
 
 ## Updates
+* 21/03/2025: We incorporate [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero), which fixes the optimization bias in GRPO.
 * 26/01/2025: We support reinforcement learning with verifiable rewards (RLVR) for math reasoning.
 * A quick [example](https://github.com/sail-sg/oat/blob/main/docs/reasoning_examples.md#deepseek-r1-zero-like-training) of R1-Zero-like training with GRPO.
-
+* 20/10/2024: We open source Oat, an online LLM alignment framework developed during a research project on online LLM exploration ([sample-efficient alignment](https://arxiv.org/pdf/2411.01493)).
 ## Introduction
 
 Oat 🌾 is a simple yet efficient framework for running **online** LLM alignment algorithms. Its key features include:
@@ -31,12 +32,12 @@ Oat 🌾 is a simple yet efficient framework for running **online** LLM alignmen
 * LLM-as-a-judge is supported via querying OpenAI API for model-based pairwise ranking.
 * **Ease of Use**: Oat's modular structure allows researchers to easily inherit and modify existing classes, enabling rapid prototyping and experimentation with new algorithms.
 * **Cutting-Edge Algorithms**: Oat implements state-of-the-art online algorithms, fostering innovation and fair benchmarking.
-* PPO (online RL) for math reasoning.
+* PPO/Dr.GRPO (online RL) for math reasoning.
 * Online DPO/SimPO/IPO for online preference learning.
 * Online exploration (active alignment) algorithms, including [SEA](https://arxiv.org/abs/2411.01493), APL and XPO.
 
 ## Installation
-In a python environment with supported versions (`>=3.8, <=3.10`), you could install oat via PyPI:
+In a python environment with supported versions (we recommend `3.10`), you could install oat via PyPI:
 ```shell
 pip install vllm==0.7.2 && pip install oat-llm
 ```
@@ -65,16 +66,20 @@ The benchmarking compares oat with the online DPO implementation from [huggingfa
 Please refer to [Appendix C of our paper](https://arxiv.org/pdf/2411.01493#page=17.64) for a detailed discussion of the benchmarking methods and results.
 
 ## Citation
-If you find this codebase useful for your research, please consider citing
+If you find this codebase useful for your research, please consider citing:
+
+LLM online alignment framework:
 ```
-@misc{liu2025oat,
-author = {Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
-title = {OAT: A research-friendly framework for LLM online alignment},
-howpublished = {https://github.com/sail-sg/oat},
-year = {2025}
+@misc{
+liu2025oat,
+title={OAT: A research-friendly framework for LLM online alignment},
+author={Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
+howpublished={\url{https://github.com/sail-sg/oat}},
+year={2025}
 }
 ```
 
+Online exploration method:
 ```
 @article{
 liu2024sea,

oat/algorithms/ppo.py

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,7 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -229,7 +229,7 @@ def _init(self, args: PPOArgs, actors: List[ActorBase]) -> None:
         super()._init(args, actors)
         self.dataset_builder = TrajectoryDataset
         self.masked_aggregator = (
-            functools.partial(masked_sum, constant_normalizer=1)
+            functools.partial(masked_sum, constant_normalizer=args.generate_max_length)
             if args.critic_type == "drgrpo"
             else masked_mean
         )
@@ -311,16 +311,16 @@ def learn(self, learning_round: int):
 
         return train_info
 
-    def compute_ppo_advantages(
-        self, rewards, input_ids, att_mask, response_masks
-    ):
+    def compute_ppo_advantages(self, rewards, input_ids, att_mask, response_masks):
         all_values = []
 
         with torch.no_grad():
             for i in range(
                 0, len(input_ids), self.args.mini_train_batch_size_per_device
             ):
-                batch_inds = torch.arange(i, i + self.args.mini_train_batch_size_per_device)
+                batch_inds = torch.arange(
+                    i, i + self.args.mini_train_batch_size_per_device
+                )
                 ## Forward critic network.
                 batch_values = self.critic(
                     input_ids=input_ids[batch_inds], attention_mask=att_mask[batch_inds]
