GRPO LoRA Single Device #2467

Open · wants to merge 7 commits into main
124 changes: 124 additions & 0 deletions recipes/configs/dev/3B_lora_grpo.yaml
@@ -0,0 +1,124 @@
# Config for single-device GRPO in dev/grpo_lora_finetune_single_device.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# It can be beneficial to first train the base model with SFT using the dev/3B_lora_grpo_sft config.
#
# To launch on a single device, run the following command from root:
# tune run dev/grpo_lora_finetune_single_device --config dev/3B_lora_grpo
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run dev/grpo_lora_finetune_single_device --config dev/3B_lora_grpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

name: grpo_llama3b_lora

output_dir: /tmp/checkpoints/${name}
base_model_path: /tmp/llama3B_gsm8k_sft_part0/epoch_0  # Use this to start from the briefly SFT-trained model (output of dev/3B_lora_grpo_sft)

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.gsm8k.gsm8k_dataset
  partition: 1-9/10
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 64  # higher increases accuracy and memory
  lora_alpha: 128  # usually alpha=2*rank
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}  # Base model from SFT includes the merged LoRA weights.
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2

save_adapter_weights_only: False
resume_from_checkpoint: False
save_every_n_epochs: 1

# Fine-tuning arguments
batch_size: 1
grpo_samples: 12  # Reduced to fit on a single device
forward_batch_size: 1
max_generated_tokens: 512
top_k: null
temperature: 1.0

ppo_epochs: 1
clip_grad_norm: 1.0

epochs: 10
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 50
loss:
  _component_: torchtune.dev.grpo.loss.GRPOSimpleLoss
  kl_coeff: 0.01
  epsilon: 0.2

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
compile: False # pytorch compile, set to true for better perf/memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: True
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
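For reference, the sketch below shows how a GRPO-style objective uses the `epsilon` and `kl_coeff` values configured above: `epsilon` clips the policy-probability ratio as in PPO, and `kl_coeff` weights a KL penalty toward the reference policy. This is an illustrative sketch only, not torchtune's `GRPOSimpleLoss` implementation; the function and argument names are made up, and padding-mask handling is omitted.

```python
import torch

def grpo_loss_sketch(logprobs, ref_logprobs, old_logprobs, advantages,
                     epsilon=0.2, kl_coeff=0.01):
    # Illustrative GRPO-style objective; all inputs are per-token tensors
    # with matching shapes.
    # Ratio between the current policy and the policy that generated the samples.
    ratio = torch.exp(logprobs - old_logprobs)
    # PPO-style clipped surrogate: epsilon bounds how far the ratio may move.
    clipped_ratio = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages)
    # k3 estimator of KL(policy || reference), weighted by kl_coeff.
    log_ratio_ref = ref_logprobs - logprobs
    kl_penalty = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
    return (policy_loss + kl_coeff * kl_penalty).mean()
```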
115 changes: 115 additions & 0 deletions recipes/configs/dev/3B_lora_sft_for_grpo.yaml
@@ -0,0 +1,115 @@
# Config for LoRA SFT for reasoning in lora_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# To launch on a single device, run the following command from root:
# tune run --nproc_per_node 1 lora_finetune_distributed --config dev/3B_lora_grpo_sft
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 1 lora_finetune_distributed --config dev/3B_lora_grpo_sft checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

name: llama3B_gsm8k_sft_part0

output_dir: /tmp/${name}

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.gsm8k.gsm8k_sft
  partition: 0-0/10
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 64  # higher increases accuracy and memory
  lora_alpha: 128  # usually alpha=2*rank
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1  # Use to increase effective batch size
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
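The GRPO config above points `checkpointer.checkpoint_dir` at this recipe's output and notes that the SFT checkpoint already includes the merged LoRA weights. As a rough illustration of what merging means (illustrative math only, not torchtune's checkpointing code; the function name is made up), the low-rank update is folded back into each base weight so the saved checkpoint loads like a regular full model:

```python
import torch

def merge_lora_weight(base_w: torch.Tensor,   # [out_dim, in_dim] frozen base weight
                      lora_a: torch.Tensor,   # [rank, in_dim] LoRA down-projection
                      lora_b: torch.Tensor,   # [out_dim, rank] LoRA up-projection
                      rank: int = 64,
                      alpha: float = 128.0) -> torch.Tensor:
    # W_merged = W + (alpha / rank) * (B @ A); with lora_rank=64 and
    # lora_alpha=128 as configured above, the scale is 2.0.
    return base_w + (alpha / rank) * (lora_b @ lora_a)
```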
1,063 changes: 1,063 additions & 0 deletions recipes/dev/grpo_lora_finetune_single_device.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions torchtune/_recipe_registry.py
@@ -31,6 +31,14 @@ class Recipe:
        ],
        supports_distributed=True,
    ),
    Recipe(
        name="dev/grpo_lora_finetune_single_device",
        file_path="dev/grpo_lora_finetune_single_device.py",
        configs=[
            Config(name="dev/3B_lora_grpo", file_path="dev/3B_lora_grpo.yaml"),
        ],
        supports_distributed=False,
    ),
    Recipe(
        name="full_finetune_single_device",
        file_path="full_finetune_single_device.py",
@@ -377,6 +385,9 @@ class Recipe:
        name="lora_finetune_distributed",
        file_path="lora_finetune_distributed.py",
        configs=[
            Config(
                name="dev/3B_lora_grpo_sft", file_path="dev/3B_lora_sft_for_grpo.yaml"
            ),
            Config(name="llama2/7B_lora", file_path="llama2/7B_lora.yaml"),
            Config(name="llama2/13B_lora", file_path="llama2/13B_lora.yaml"),
            Config(name="llama2/70B_lora", file_path="llama2/70B_lora.yaml"),