diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py
index 145f5cd661..591ddd5c36 100644
--- a/recipes/dev/grpo_full_finetune_distributed.py
+++ b/recipes/dev/grpo_full_finetune_distributed.py
@@ -646,9 +646,15 @@ def generate_trajectory(
         # Do some reward modelingggggggg
         # responses :: [B x G, L]
         responses = responses.reshape(batch_size, grpo_size, -1)  # [B, G, L]
-        rewards, successes = batched_rewards(self._tokenizer, responses, answers)
-        rewards = rewards.to(self._device)  # [B, G]
-        successes = successes.to(self._device)  # [B, G]
+        rewards, successes, metadata = batched_rewards(
+            self._tokenizer, responses, answers, device=self._device
+        )
+        rewards = rewards.to(self._device)  # [B, G, num_reward_funcs]
+        successes = successes.to(self._device)  # [B, G, num_reward_funcs]
+
+        # Aggregate rewards and successes across reward functions
+        rewards = rewards.sum(dim=-1)  # [B, G]
+        successes = successes.sum(dim=-1)  # [B, G]

         advantages = (rewards - rewards.mean(1, keepdim=True)) / (
             rewards.std(1, keepdim=True) + 1e-4
@@ -672,6 +678,7 @@ def generate_trajectory(
             position_ids=position_ids,
             response_padding_masks=response_padding_masks,
             seq_lens=training.get_unmasked_sequence_lengths(response_padding_masks),
+            answers=answers,
         )

     def generate_trajectory_batched(
@@ -703,7 +710,22 @@ def generate_trajectory_batched(
                     self.generate_trajectory(batch_input_ids, batch_answers)
                 )
                 torch.cuda.empty_cache()
-        return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
+
+        # Concatenate all trajectory fields except answers (which is a list of strings)
+        concatenated_fields = {}
+        for field_name in trajectories[0]._fields:
+            if field_name == "answers":
+                # Concatenate the lists of answers
+                concatenated_fields[field_name] = []
+                for traj in trajectories:
+                    concatenated_fields[field_name].extend(traj.answers)
+            else:
+                # Concatenate tensors along the batch dimension
+                concatenated_fields[field_name] = torch.cat(
+                    [getattr(traj, field_name) for traj in trajectories]
+                )
+
+        return GRPOTrajectory(**concatenated_fields)

     def grpo_step(
         self,
@@ -771,6 +793,7 @@ def grpo_step(
             ratios,
             clipfrac,
             approx_policy_kls,
+            None,  # metadata
         )

     def train(self) -> None:
@@ -853,9 +876,21 @@ def train(self) -> None:
                 if grad_norm is not None:
                     extra_metrics["grad_norm"] = grad_norm

+                # Stack GRPOStats fields, skipping the non-tensor metadata field
+                concatenated_stats = {}
+                for field_name in grpo_stats[0]._fields:
+                    if field_name == "metadata":
+                        # grpo_step sets metadata to None, so there is nothing to stack
+                        concatenated_stats[field_name] = None
+                    else:
+                        # Stack tensors across GRPO steps
+                        concatenated_stats[field_name] = torch.stack(
+                            [getattr(stat, field_name) for stat in grpo_stats]
+                        )
+
                 self.log_metrics(
                     trajectory,
-                    GRPOStats(*map(torch.stack, zip(*grpo_stats))),
+                    GRPOStats(**concatenated_stats),
                     **extra_metrics,
                 )

diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 95c45ee9b0..42dc40a74f 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -298,8 +298,7 @@ def batched_rewards(

     for b in range(batch_size):
         for g in range(grpo_size):
-
-            answer = answers[b][g]
+            answer = answers[b]

             text_completion = tokenizer.decode(completions[b, g].tolist())

diff --git a/torchtune/dev/rl/types.py b/torchtune/dev/rl/types.py
index b0aae365ff..3120d0b826 100644
--- a/torchtune/dev/rl/types.py
+++ b/torchtune/dev/rl/types.py
@@ -19,6 +19,8 @@ class GRPOTrajectory(NamedTuple):
         logprobs (torch.Tensor): Log probabilities of the generated responses with shape [B x G, L].
         ref_logprobs (torch.Tensor): Log probabilities of the generated responses using the reference policy with shape [B x G, L].
         advantages (torch.Tensor): Advantage estimates for the generated responses with shape [B x G].
+        rewards (torch.Tensor): Reward values for the generated responses with shape [B x G].
+        successes (torch.Tensor): Success indicators for the generated responses with shape [B x G].
         masks (torch.Tensor): Attention masks for input ids-generated responses pairs with shape [B x G, P+L, P+L].
         position_ids (torch.Tensor): Position IDs for input ids-generated responses pairs with shape [B x G, P+L].
         response_padding_masks (torch.Tensor): Padding masks for the truncated and padded generated responses with shape [B x G, L].
@@ -30,6 +32,8 @@ class GRPOTrajectory(NamedTuple):
     logprobs: torch.Tensor = None  # [B x G, L]
     ref_logprobs: torch.Tensor = None  # [B x G, L]
     advantages: torch.Tensor = None  # [B x G]
+    rewards: torch.Tensor = None  # [B x G]
+    successes: torch.Tensor = None  # [B x G]
     masks: torch.Tensor = None  # [B x G, P+L, P+L]
     position_ids: torch.Tensor = None  # [B x G, P+L]
     response_padding_masks: torch.Tensor = None  # [B x G, L]
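
Note: a minimal sketch (not part of the patch) of how the new reward shape feeds the advantage computation in generate_trajectory. It assumes batched_rewards returns a [B, G, num_reward_funcs] tensor, which is summed over the reward functions and then normalized within each group of G completions; the toy sizes and random values are illustrative only.

import torch

B, G, num_reward_funcs = 2, 4, 3  # toy sizes, illustrative only

# Stand-in for the [B, G, num_reward_funcs] tensor returned by batched_rewards
rewards = torch.rand(B, G, num_reward_funcs)

# Aggregate across reward functions, as in generate_trajectory
rewards = rewards.sum(dim=-1)  # [B, G]

# Group-relative advantages: each prompt's G completions are normalized against each other
advantages = (rewards - rewards.mean(1, keepdim=True)) / (
    rewards.std(1, keepdim=True) + 1e-4
)
print(advantages.shape)  # torch.Size([2, 4])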