diff --git a/recipes/configs/dev/3B_lora_grpo.yaml b/recipes/configs/dev/3B_lora_grpo.yaml
new file mode 100644
index 0000000000..3fd91f80e3
--- /dev/null
+++ b/recipes/configs/dev/3B_lora_grpo.yaml
@@ -0,0 +1,124 @@
+# Config for single-device LoRA GRPO in dev/grpo_lora_finetune_single_device.py
+# using a Llama3.2 3B Base model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
+#
+# It can be beneficial to first train the base model with SFT using the dev/3B_lora_grpo_sft config.
+#
+# To launch on 1 device, run the following command from root:
+#   tune run --nproc_per_node 1 dev/grpo_lora_finetune_single_device --config dev/3B_lora_grpo
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nproc_per_node 1 dev/grpo_lora_finetune_single_device --config dev/3B_lora_grpo checkpointer.checkpoint_dir=
+
+name: grpo_llama3b_lora
+
+output_dir: /tmp/checkpoints/${name}
+base_model_path: /tmp/llama3B_gsm8k_sft_part0/epoch_0  # Use this to train from the lightly trained SFT model
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Llama-3.2-3B/original/tokenizer.model
+  max_seq_len: null
+
+# Dataset
+dataset:
+  _component_: torchtune.dev.grpo.gsm8k.gsm8k_dataset
+  partition: 1-9/10
+seed: null
+shuffle: False
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
+  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+  apply_lora_to_mlp: True
+  lora_rank: 64  # higher increases accuracy and memory
+  lora_alpha: 128  # usually alpha=2*rank
+  lora_dropout: 0.0
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: ${base_model_path} # Base model from SFT; LoRA weights are already merged.
+  checkpoint_files: [
+    model-00001-of-00002.safetensors,
+    model-00002-of-00002.safetensors,
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3_2
+
+save_adapter_weights_only: False
+resume_from_checkpoint: False
+save_every_n_epochs: 1
+
+# Fine-tuning arguments
+batch_size: 1
+grpo_samples: 12 # Reduced to fit on a single device
+forward_batch_size: 1
+max_generated_tokens: 512
+top_k: null
+temperature: 1.0
+
+ppo_epochs: 1
+clip_grad_norm: 1.0
+
+epochs: 10
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 1e-5
+  fused: True
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 50
+loss:
+  _component_: torchtune.dev.grpo.loss.GRPOSimpleLoss
+  kl_coeff: 0.01
+  epsilon: 0.2
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+compile: False  # pytorch compile, set to true for better perf/memory
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  #Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  #`torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  #trace options passed to `torch.profiler.profile`
+  profile_memory: True
+  with_stack: True
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
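A note on the `partition` fields before the companion SFT config below: GRPO trains on GSM8K shards 1-9 of 10, while SFT holds out shard 0, so the two stages never share examples. A minimal sketch of the assumed `a-b/n` convention (`partition_indices` is a hypothetical helper for illustration; the authoritative parsing lives in `torchtune.dev.grpo.gsm8k`):

```python
# Hypothetical helper illustrating the assumed "a-b/n" partition convention:
# shards a..b (inclusive) out of n equal slices of the dataset.
def partition_indices(num_examples: int, spec: str) -> range:
    shard_range, num_shards = spec.split("/")
    start, end = (int(s) for s in shard_range.split("-"))
    shard_size = num_examples // int(num_shards)
    return range(start * shard_size, (end + 1) * shard_size)

# GSM8K's train split has ~7473 examples: SFT sees shard 0, GRPO shards 1-9.
sft_rows = partition_indices(7473, "0-0/10")   # roughly the first 747 examples
grpo_rows = partition_indices(7473, "1-9/10")  # roughly the remaining 6723
```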
diff --git a/recipes/configs/dev/3B_lora_sft_for_grpo.yaml b/recipes/configs/dev/3B_lora_sft_for_grpo.yaml
new file mode 100644
index 0000000000..1d9870566a
--- /dev/null
+++ b/recipes/configs/dev/3B_lora_sft_for_grpo.yaml
@@ -0,0 +1,115 @@
+# Config for LoRA SFT for reasoning in lora_finetune_distributed.py
+# using a Llama3.2 3B Base model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on a single device, run the following command from root:
+#   tune run --nproc_per_node 1 lora_finetune_distributed --config dev/3B_lora_grpo_sft
+#
+# You can add specific overrides through the command line.
For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 1 lora_finetune_distributed --config dev/3B_lora_grpo_sft checkpointer.checkpoint_dir= + +name: llama3B_gsm8k_sft_part0 + +output_dir: /tmp/${name} + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-3B/original/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.dev.grpo.gsm8k.gsm8k_sft + partition: 0-0/10 +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.lora_llama3_2_3b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-3B/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: LLAMA3_2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 + +optimizer: + _component_: torch.optim.AdamW + lr: 1e-5 + fused: True +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 +gradient_accumulation_steps: 1 # Use to increase effective batch size +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/dev/grpo_lora_finetune_single_device.py b/recipes/dev/grpo_lora_finetune_single_device.py new file mode 100644 index 0000000000..86a9b55d7d --- /dev/null +++ b/recipes/dev/grpo_lora_finetune_single_device.py @@ -0,0 +1,1063 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import sys
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
+from warnings import warn
+
+import torch
+import torchtune.modules.common_utils as common_utils
+from omegaconf import DictConfig, ListConfig
+from torch import nn
+from torch.optim import Optimizer
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchtune import config, generation, modules, rlhf, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.datasets import ConcatDataset
+from torchtune.dev.grpo.generation import generate
+from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward
+from torchtune.dev.grpo.types import GRPOStats, GRPOTrajectory
+from torchtune.modules import local_kv_cache
+from torchtune.modules.peft import (
+    disable_adapter,
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
+from torchtune.training.lr_schedulers import get_lr
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class LoraGRPOFinetuneRecipeSingleDevice(FTRecipeInterface):
+    """
+    LoRA finetuning recipe for dense transformer-based LLMs such as Llama2, trained with GRPO.
+    This recipe is optimized for single GPU training. Training on CPU is not supported.
+
+    Features:
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass. This is especially
+            helpful for larger batch sizes when you're memory constrained. But these savings in memory
+            come at the cost of training performance. In most cases training can slow down quite a bit as
+            a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activation checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activation offloading will drop the activation in the forward to the CPU and bring it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch 2.5 or later and will
+            be enabled by default if an acceptable torch version is found. Activation offloading can be
+            used in conjunction with activation checkpointing.
+
+        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+            most cases this should halve the memory footprint of full precision (fp32) training, without
+            loss in model quality (will depend on the model, training data and other settings). For
+            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+            precision are currently not supported.
+
+        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+            training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
+            only saved at the end of a given epoch and used in case of resuming training.
+
+            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+            currently not supported.
+
+            For more details on the checkpointer, please take a look at
+            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
+
+        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
+
+    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+    has example commands for how to kick-off training.
+
+    Args:
+        cfg (DictConfig): OmegaConf object parsed from yaml file
+
+    Raises:
+        ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        self._device = utils.get_device(device=cfg.device)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+        if self._dtype == torch.float16:
+            raise ValueError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        # logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+        if self._log_peak_memory_stats and self._device.type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if self._device.type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif (
+            self._enable_activation_checkpointing
+            and cfg.checkpointer.model_type != "LLAMA3_VISION"
+        ):
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. 
" + "Enabling activation offloading should reduce memory further.", + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed( + seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) + ) + self.total_epochs = cfg.epochs + self.global_step = 0 + self._steps_run = 0 + self._epochs_run = 0 + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + + self._rng = torch.Generator(self._device).manual_seed(self.seed) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + should_load_recipe_state=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + # Check for adapter weights in checkpoint + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self._epochs_run = ckpt_dict[training.EPOCHS_KEY] + self._rng.set_state(ckpt_dict[training.RNG_KEY]) + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + # hack to toggle to the low cpu ram version of the reparametrize_as_dtype + # hook based on the config. 
+        common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)
+
+        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+
+        self._compile = cfg.get("compile", False)
+        self._model = self._setup_model(
+            cfg_model=cfg.model,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
+            compile_model=self._compile,
+            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            lora_weights_state_dict=(
+                checkpoint_dict[training.ADAPTER_KEY]
+                if training.ADAPTER_KEY in checkpoint_dict
+                else None
+            ),
+        )
+        self._tokenizer = config.instantiate(cfg.tokenizer)
+
+        self._optimizer = self._setup_optimizer(
+            cfg_optimizer=cfg.optimizer,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
+        )
+
+        # initialize loss
+        self._loss_fn = config.instantiate(cfg.loss)
+
+        if self._compile:
+            training.compile_loss(self._loss_fn)
+
+        log.info("Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # setup after both of these are initialized
+        collate_name = cfg.get(
+            "collate_fn", "torchtune.dev.grpo.data.padded_collate_rl"
+        )
+        self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+            collate_fn=collate_name,
+            dataloader_state_dict=(
+                checkpoint_dict[training.DATALOADER_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
+        )
+
+        # Finally update the recipe state, which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # The number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader. This value is used for logging and tracking training state,
+        # so the computation must happen after the dataloader has been set up.
+        self._steps_per_epoch = len(self._dataloader)
+        self.global_step = self._epochs_run * self._steps_per_epoch
+
+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+        # RL params
+        self.grpo_samples = cfg.grpo_samples
+        self._temperature = cfg.temperature
+        self._top_k = cfg.top_k
+        self._max_generated_tokens = cfg.max_generated_tokens
+        self.batch_size = cfg.batch_size
+        self._forward_batch_size = cfg.forward_batch_size
+
+        self._ppo_epochs = cfg.ppo_epochs
+
+        self._save_every_n_epochs = cfg.save_every_n_epochs
+
+        if cfg.get("stop_token_ids", False):
+            stop_token_ids = cfg.stop_token_ids
+            if self._tokenizer.eos_id not in stop_token_ids:
+                warn(
+                    f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids}). "
+                    "This may lead to unexpected behaviour."
+                )
+        else:
+            if not hasattr(self._tokenizer, "stop_tokens"):
+                warn(
+                    "No stop tokens defined in tokenizer, and no stop_token_ids provided. "
+                    "This may lead to unexpected behaviour."
+                )
+                stop_token_ids = []
+            else:
+                stop_token_ids = self._tokenizer.stop_tokens
+        self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device)
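For intuition on the RL parameters read in `setup()` above: every prompt in a batch is expanded into `grpo_samples` completions (one GRPO group), and prompts pass through generation `forward_batch_size` at a time to bound peak memory. Illustrative arithmetic only, using the values from `dev/3B_lora_grpo.yaml`:

```python
# Values from dev/3B_lora_grpo.yaml (illustrative arithmetic only)
batch_size, grpo_samples, forward_batch_size = 1, 12, 1

sequences_per_optim_step = batch_size * grpo_samples      # 12 (query, response) pairs per step
generation_passes = -(-batch_size // forward_batch_size)  # ceil(1 / 1) = 1 generation micro-batch
print(sequences_per_optim_step, generation_passes)        # 12 1
```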
+    def _setup_optimizer(
+        self,
+        cfg_optimizer: DictConfig,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optimizer:
+        """
+        Set up the optimizer based on the provided configuration.
+        """
+        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+        if opt_state_dict:
+            optimizer.load_state_dict(opt_state_dict)
+
+        log.info("Optimizer is initialized.")
+        return optimizer
+
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[torch.optim.lr_scheduler.LRScheduler]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[torch.optim.lr_scheduler.LRScheduler]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
+        optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up the profiler.
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` when profiling is
+            disabled, such that the instrumented training loop does not need to be changed.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+
+            profiler:
+                enabled: bool
+
+                #Output directory of trace artifacts
+                output_dir: str
+
+                #`torch.profiler.ProfilerActivity` types to trace
+                cpu: bool
+                cuda: bool
+
+                #Trace options
+                profile_memory: bool
+                with_stack: bool
+                record_shapes: bool
+                with_flops: bool
+
+                # `torch.profiler.schedule` options:
+                # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+                wait_steps: int
+                warmup_steps: int
+                active_steps: int
+                num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        log.info(f"Profiler config after instantiation: {profiler_cfg}")
+        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+        if profiler_cfg["enabled"]:
+            self.profiler_wait_steps = profiler_cfg["wait_steps"]
+            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+            self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
+    def _setup_data(
+        self,
+        cfg_dataset: DictConfig,
+        shuffle: bool,
+        batch_size: int,
+        collate_fn: str,
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
+        """
+        All data related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a training run),
+        it is loaded into the dataloader.
+        """
+        if isinstance(cfg_dataset, ListConfig):
+            datasets = [
+                config.instantiate(single_cfg_dataset, self._tokenizer)
+                for single_cfg_dataset in cfg_dataset
+            ]
+            ds = ConcatDataset(datasets=datasets)
+            packed = getattr(ds, "packed", False)
+        else:
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
+            packed = cfg_dataset.get("packed", False)
+
+        # Instantiate collate_fn
+        collate_fn = _get_component_from_path(collate_fn)
+
+        dataloader = StatefulDataLoader(
+            dataset=ds,
+            shuffle=shuffle,
+            batch_size=batch_size,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                )
+            ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+        )
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+            # Because we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
+        return dataloader
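The `list(dataloader)` call above drains the loader so a state captured mid-epoch does not resume partway through when only epoch-boundary checkpoints are saved. A minimal sketch of the underlying `StatefulDataLoader` resume pattern, assuming default map-style semantics:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Minimal sketch of the resume behaviour relied on in _setup_data: a
# StatefulDataLoader can checkpoint its position and pick up where it left off.
loader = StatefulDataLoader(list(range(10)), batch_size=2)
it = iter(loader)
print(next(it).tolist())         # [0, 1] -- consume one batch
state = loader.state_dict()      # captures the position after the first batch

resumed = StatefulDataLoader(list(range(10)), batch_size=2)
resumed.load_state_dict(state)   # next iteration starts at the second batch
print([batch.tolist() for batch in resumed])  # [[2, 3], [4, 5], [6, 7], [8, 9]]
```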
+ """ + ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # if training is in-progress, checkpoint the optimizer state as well + if intermediate_checkpoint: + ckpt_dict.update( + { + training.OPT_KEY: self._optimizer.state_dict(), + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self._epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self._steps_per_epoch, + training.DATALOADER_KEY: self._dataloader.state_dict(), + } + ) + + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + if not self._save_adapter_weights_only: + # Construct the full state dict with LoRA weights merged into base LLM weights + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + + ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + compile_model: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model with + the right dtype + """ + with training.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + # Configure LoRA parameters + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + # Load base model weights + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) + + # Get adapter parameters after model is loaded + self.adapter_params = get_adapter_params(model) + self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) + + # Set trainable parameters + set_trainable_params(model, self.adapter_params) + + if compile_model: + training.compile_model(model) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # Ensure model is on the correct device + model = model.to(self._device) + + # Initialize RoPE and DoRA if needed + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + # Ensure all buffers are on the correct TODO: may not be needed, try removing. 
+    def _setup_model(
+        self,
+        cfg_model: DictConfig,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        compile_model: bool,
+        base_model_state_dict: Dict[str, Any],
+        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> nn.Module:
+        """
+        Model initialization has some important considerations:
+        a. To minimize GPU peak memory, we initialize the model with
+           the right dtype
+        """
+        with training.set_default_dtype(self._dtype), self._device:
+            model = config.instantiate(cfg_model)
+
+        # Configure LoRA parameters
+        self._lora_rank = cfg_model.lora_rank
+        self._lora_alpha = cfg_model.lora_alpha
+        self._lora_attn_modules = list(cfg_model.lora_attn_modules)
+        self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
+        self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+
+        # Load base model weights
+        base_missing, base_unexpected = model.load_state_dict(
+            base_model_state_dict, strict=False
+        )
+
+        # Get adapter parameters after model is loaded
+        self.adapter_params = get_adapter_params(model)
+        self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
+
+        # Set trainable parameters
+        set_trainable_params(model, self.adapter_params)
+
+        if compile_model:
+            training.compile_model(model)
+
+        if enable_activation_checkpointing:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure model is on the correct device
+        model = model.to(self._device)
+
+        # Initialize RoPE and DoRA if needed
+        for m in model.modules():
+            # RoPE is not covered in state dict
+            if hasattr(m, "rope_init"):
+                m.rope_init()
+            # Ensure all buffers are on the correct device. TODO: may not be needed, try removing.
+            for buffer_name, buffer in m._buffers.items():
+                if buffer is not None and buffer.device != self._device:
+                    m._buffers[buffer_name] = buffer.to(self._device)
+
+        # Initialize DoRA magnitude if needed
+        if self._is_dora:
+            for m in model.modules():
+                if hasattr(m, "initialize_dora_magnitude"):
+                    m.initialize_dora_magnitude()
+
+        # Load adapter weights if provided
+        if lora_weights_state_dict:
+            lora_missing, lora_unexpected = model.load_state_dict(
+                lora_weights_state_dict, strict=False
+            )
+
+            # Log any issues with loading
+            if lora_missing:
+                log.warning(
+                    f"Missing keys when loading adapter weights: {lora_missing}"
+                )
+            if lora_unexpected:
+                log.warning(
+                    f"Unexpected keys when loading adapter weights: {lora_unexpected}"
+                )
+        else:
+            lora_missing, lora_unexpected = None, None
+
+        validate_missing_and_unexpected_for_lora(
+            lora_attn_modules=self._lora_attn_modules,
+            apply_lora_to_mlp=self._apply_lora_to_mlp,
+            apply_lora_to_output=self._apply_lora_to_output,
+            base_missing=base_missing,
+            base_unexpected=base_unexpected,
+            lora_missing=lora_missing,
+            lora_unexpected=lora_unexpected,
+        )
+
+        disable_dropout(model)
+
+        log.info(f"Model is initialized with precision {self._dtype}.")
+
+        if self._device.type != "cpu":
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
+
+        return model
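Before the trajectory-generation code below, one design point is worth spelling out: because the policy is base weights plus LoRA, the frozen reference policy needed for the KL term is recovered by simply disabling the adapters, so no second copy of the model is kept in memory. Schematically (a sketch of the pattern with a generic model argument, not the recipe's exact call):

```python
import torch
from torchtune.modules.peft import disable_adapter

def policy_and_reference_logits(model: torch.nn.Module, tokens: torch.Tensor):
    pi_logits = model(tokens)       # current policy: base weights + LoRA
    with disable_adapter(model):    # temporarily bypass the LoRA layers
        ref_logits = model(tokens)  # reference policy: base weights only
    return pi_logits, ref_logits
```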
+ """ + batch_size, context_length = input_ids.shape + grpo_size = self.grpo_samples + + batch_input_ids = input_ids[:, None, :].expand(-1, grpo_size, -1) # [B, G, L] + batch_input_ids = batch_input_ids.reshape(batch_size * grpo_size, -1) + + # step 1: generate responses, and logits corresponding to the responses using the current policy + with local_kv_cache( + model=self._model, + batch_size=batch_size * grpo_size, + device=self._device, + dtype=self._dtype, + decoder_max_seq_len=context_length + self._max_generated_tokens, + ): + # Make sure all tensors are on the correct device + batch_input_ids = batch_input_ids.to(self._device) + + # Use the device-specific stop tokens + stop_tokens = self._tokenizer.stop_tokens + if isinstance(stop_tokens, torch.Tensor): + stop_tokens = stop_tokens.to(self._device) + + query_responses, _ = generate( # [B x G, L], [B x G, L, V] + model=self._model, + prompt=batch_input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + stop_tokens=stop_tokens, + return_logits=False, + ) + + torch.cuda.empty_cache() + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses != self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = generation.get_causal_mask_from_padding_mask( + query_response_padding_masks + ).to(self._device) + position_ids = generation.get_position_ids_from_padding_mask( + query_response_padding_masks + ).to(self._device) + + del query_response_padding_masks + + logits = self._model(query_responses, input_pos=position_ids, mask=masks) + + # step 2. estimate logprobs of the responses using the current policy + logits = logits[:, context_length - 1 :] + logprobs = rlhf.batched_logits_to_logprobs(logits, responses, self._temperature) + + del logits + torch.cuda.empty_cache() + + # step 2.1 estimate logprobs of the responses using the reference policy (main model with adapters disabled) + with disable_adapter(self._model): + ref_logits = self._model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.batched_logits_to_logprobs( + ref_logits, responses, self._temperature + ) + + del ref_logits + torch.cuda.empty_cache() + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + ( + response_padding_masks, + responses, + ) = rlhf.truncate_sequence_at_first_stop_token( # [B x G, L] + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # responses :: [B x G, L] + responses = responses.reshape(batch_size, grpo_size, -1) # [B, G, L] + + rewards, successes = batch_shaped_correctness_reward( + self._tokenizer, responses, answers + ) # [B, G] + rewards = rewards.to(self._device) + successes = successes.to(self._device) + + advantages = (rewards - rewards.mean(1, keepdim=True)) / ( + rewards.std(1, keepdim=True) + 1e-4 + ) + advantages = advantages.reshape(batch_size * grpo_size) # flatten + + del responses + torch.cuda.empty_cache() + + seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) + + # step 6. 
+    def generate_trajectory_batched(
+        self, input_ids: torch.Tensor, answers: List[str]
+    ) -> GRPOTrajectory:
+        """
+        Generates a ``self.batch_size`` batch of trajectories using ``self._forward_batch_size`` batch sizes.
+        See ``generate_trajectory`` for more details.
+
+        Args:
+            input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length]
+            answers (List[str]): list of answers corresponding to the input_ids
+
+        Returns:
+            GRPOTrajectory: An instance of :class:`~torchtune.dev.grpo.types.GRPOTrajectory`, comprising
+                the current trajectory.
+        """
+        trajectories: List[GRPOTrajectory] = []
+        with torch.no_grad():
+            for batch_start in range(0, self.batch_size, self._forward_batch_size):
+                batch_input_ids = input_ids[
+                    batch_start : batch_start + self._forward_batch_size
+                ]
+                batch_answers = answers[
+                    batch_start : batch_start + self._forward_batch_size
+                ]
+                torch.cuda.empty_cache()
+                trajectories.append(
+                    self.generate_trajectory(batch_input_ids, batch_answers)
+                )
+                torch.cuda.empty_cache()
+        return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
+
+    def grpo_step(
+        self,
+        trajectory: GRPOTrajectory,
+        context_length: int,
+    ) -> GRPOStats:
+        """
+        Perform a single GRPO optimization step over a batch of trajectories and corresponding advantages and returns.
+
+        Args:
+            trajectory (GRPOTrajectory): a batch of trajectories
+            context_length (int): input ids sequence length
+
+        Returns:
+            GRPOStats: An instance of :class:`~torchtune.dev.grpo.types.GRPOStats`, a NamedTuple containing:
+               - loss (torch.Tensor): The total GRPO loss.
+               - policy_loss (torch.Tensor): The policy-gradient component of the loss.
+               - kl_loss (torch.Tensor): The KL-penalty component of the loss.
+               - ratios (torch.Tensor): The ratio between the current and old policy probabilities.
+               - clipfrac (torch.Tensor): The fraction of ratios that were clipped.
+               - approx_policy_kls (torch.Tensor): Average estimated KL divergence between the policy before and after the optimisation step.
+        """
+        # estimate logprobs from the policy at the current optimisation step
+        torch.cuda.empty_cache()
+
+        pi_logits = self._model(
+            trajectory.query_responses,
+            input_pos=trajectory.position_ids,
+            mask=trajectory.masks,
+        )
+
+        pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length)
+        pi_logprobs = rlhf.batched_logits_to_logprobs(
+            pi_logits,
+            trajectory.query_responses[:, context_length:],
+            self._temperature,
+            chunk_size=1,
+        )
+
+        pi_logprobs[trajectory.response_padding_masks] = 1.0
+
+        del pi_logits
+        torch.cuda.empty_cache()
+
+        # calculate grpo loss
+        loss, policy_loss, kl_loss, ratios, clipfrac = self._loss_fn(
+            trajectory.logprobs,
+            pi_logprobs,
+            trajectory.ref_logprobs,
+            trajectory.advantages,
+            padding_masks=~trajectory.response_padding_masks,
+        )
+
+        torch.cuda.empty_cache()
+        loss.backward()
+
+        with torch.no_grad():
+            approx_policy_kls = (
+                0.5 * (pi_logprobs - trajectory.logprobs).pow(2)
+            ).mean()
+
+        return GRPOStats(
+            loss,
+            policy_loss,
+            kl_loss,
+            ratios,
+            clipfrac,
+            approx_policy_kls,
+        )
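`grpo_step` delegates the objective to `GRPOSimpleLoss`, configured with `epsilon=0.2` and `kl_coeff=0.01` in the config above. Assuming it follows the standard GRPO objective, a clipped PPO-style surrogate plus a KL penalty toward the reference policy (see `torchtune.dev.grpo.loss` for the authoritative version), a per-token sketch looks like:

```python
import torch

# Hedged sketch of a GRPO-style loss; GRPOSimpleLoss itself also returns the
# individual components (policy_loss, kl_loss, ratios, clipfrac).
def grpo_loss_sketch(pi_logprobs, old_logprobs, ref_logprobs, advantages,
                     padding_masks, epsilon=0.2, kl_coeff=0.01):
    """Logprob tensors: [B*G, L]; advantages: [B*G]; padding_masks: [B*G, L] (True = keep)."""
    ratios = torch.exp(pi_logprobs - old_logprobs)
    clipped = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
    adv = advantages[:, None]                      # broadcast advantage over tokens
    policy_loss = -torch.min(ratios * adv, clipped * adv)
    # k3 estimator of KL(pi || ref), penalizing drift from the adapter-disabled policy
    kl = torch.exp(ref_logprobs - pi_logprobs) - (ref_logprobs - pi_logprobs) - 1
    per_token = policy_loss + kl_coeff * kl
    return (per_token * padding_masks).sum() / padding_masks.sum()
```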
+ """ + # clean up before training begins + training.cleanup_before_training() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + grad_norm = None + + training_completed = False + stop_profiler_idx = ( + self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ) + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self._epochs_run, self.total_epochs): + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + tokens = batch["tokens"] # type: ignore + answers = batch["answers"] # type: ignore + tokens = tokens.to(self._device) # [B, P] + + _, context_length = tokens.shape + + trajectory = self.generate_trajectory_batched(tokens, answers) + + grpo_stats: list[GRPOStats] = [] + for _ in range(self._ppo_epochs): + + step_stats = self.grpo_step(trajectory, context_length) + + grpo_stats.append(step_stats) + + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + extra_metrics = {} + extra_metrics["lr"] = get_lr(self._optimizer) + if grad_norm is not None: + extra_metrics["grad_norm"] = grad_norm + + self.log_metrics( + trajectory, + GRPOStats(*map(torch.stack, zip(*grpo_stats))), + **extra_metrics, + ) + + self.cleanup_after_step(trajectory, grpo_stats) + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == stop_profiler_idx + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler - TODO move to distributed recipe. + self._profiler.step() + pbar.update(1) + + self._epochs_run += 1 + if self._epochs_run % self._save_every_n_epochs == 0: + self.save_checkpoint(curr_epoch) + + self._profiler.stop() + + def log_metrics( + self, trajectory: GRPOTrajectory, grpo_stats: GRPOStats, **extras + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. 
+ """ + rewards = trajectory.rewards.mean() + successes = trajectory.successes.mean() + + log_dict = { + "rewards": rewards, + "successes": successes, + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "loss": grpo_stats.loss.mean(), + "policy_loss": grpo_stats.policy_loss.mean(), + "kl_loss": grpo_stats.kl_loss.mean(), + "clipfrac": grpo_stats.clipfrac.mean(), + "ratios": grpo_stats.ratios.mean(), + "approx_policy_kl": grpo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + **extras, + } + + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup(self) -> None: + self._metric_logger.close() + + def cleanup_after_step( + self, + trajectory: GRPOTrajectory, + l_grpo_stats: list[GRPOStats], + ) -> None: + for v in trajectory: + del v + del trajectory + for g in l_grpo_stats: + for v in g: + del v + del g + del l_grpo_stats + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + + recipe = LoraGRPOFinetuneRecipeSingleDevice(cfg=cfg) + config.log_config(recipe_name="LoraGRPOFinetuneRecipeSingleDevice", cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 17e0a6d379..80994653b7 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -31,6 +31,14 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="dev/grpo_lora_finetune_single_device", + file_path="dev/grpo_lora_finetune_single_device.py", + configs=[ + Config(name="dev/3B_lora_grpo", file_path="dev/3B_lora_grpo.yaml"), + ], + supports_distributed=False, + ), Recipe( name="full_finetune_single_device", file_path="full_finetune_single_device.py", @@ -377,6 +385,9 @@ class Recipe: name="lora_finetune_distributed", file_path="lora_finetune_distributed.py", configs=[ + Config( + name="dev/3B_lora_grpo_sft", file_path="dev/3B_lora_sft_for_grpo.yaml" + ), Config(name="llama2/7B_lora", file_path="llama2/7B_lora.yaml"), Config(name="llama2/13B_lora", file_path="llama2/13B_lora.yaml"), Config(name="llama2/70B_lora", file_path="llama2/70B_lora.yaml"),