diff --git a/src/genrl/trainer/grpo_trainer.py b/src/genrl/trainer/grpo_trainer.py
index 115db7b..4b60523 100644
--- a/src/genrl/trainer/grpo_trainer.py
+++ b/src/genrl/trainer/grpo_trainer.py
@@ -41,6 +41,7 @@ class GRPOTrainerConfig:
     min_p: float | None = None
     repetition_penalty: float = 1.0
     num_iterations: int = 1
+    # Name of the optimizer to instantiate; only "Adam" or "SGD" are accepted.
+    optimizer: str = "Adam"
 
 
 class GRPOLanguageTrainerModule(TrainerModule, LoggerMixin):
@@ -67,9 +68,17 @@ def __init__(self, models: List[Any], config: GRPOTrainerConfig, **kwargs):
 
         self.args = config
 
-        self.optimizer = torch.optim.Adam(
-            self.model.parameters(), lr=self.args.learning_rate
-        )
+        # Select the optimizer from config; reject anything unsupported up
+        # front so a typo fails here rather than later with AttributeError.
+        match config.optimizer:
+            case "Adam":
+                self.optimizer = torch.optim.Adam(
+                    self.model.parameters(), lr=self.args.learning_rate
+                )
+            case "SGD":
+                self.optimizer = torch.optim.SGD(self.model.parameters(),
+                                                 lr=self.args.learning_rate,
+                                                 momentum=0.9)
+            case _:
+                raise ValueError("For GRPO training only SGD or Adam optimizers are supported")
 
         # Tokenizers
         self.processing_class = kwargs.get("processing_class", None)
@@ -446,7 +455,7 @@ def step(
 
         loss.backward()
         self.optimizer.step()
-        self.model.zero_grad()
+        self.model.zero_grad(set_to_none=True)
 
         metrics = {"train/loss": loss.cpu().mean().item()}
         metrics.update({"train/rewards": rewards.cpu().mean().item()})