diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 59d9ec3..8fca722 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -109,9 +109,10 @@ jobs: test-modal-recompute: # These tests verify that recomputation options do not change the results at all - name: Test recompute - ${{ matrix.recompute.name }} - ${{ matrix.dtype.name }} + name: Recompute - ${{ matrix.recompute.name }} - ${{ matrix.dtype.name }} needs: deploy-modal runs-on: ubuntu-latest + if: false strategy: fail-fast: false max-parallel: 3 @@ -129,7 +130,7 @@ jobs: args: "--offload-opt-m --offload-opt-v --offload-master" # While not strictly a recomputation, chunked attention should be bitwise identical, too - name: "Chunked attention" - args: "--recompute-att --attn-bwd-chunks 4" + args: "--recompute-att --attn-bwd-chunks=2" dtype: - name: "BF16" args: "--matmul-dtype=bf16" @@ -158,6 +159,7 @@ jobs: name: Test fixed - ${{ matrix.config.name }} needs: deploy-modal runs-on: ubuntu-latest + if: false strategy: fail-fast: false max-parallel: 3 @@ -192,11 +194,59 @@ jobs: - name: Run test on Modal run: python3 scripts/modal_test_ci.py ${{ matrix.config.args }} + test-modal-multi-gpu: + name: Test Multi-GPU - ${{ matrix.config.name }} + #needs: + # - test-modal-fixed + # - test-modal-recompute + needs: deploy-modal + + runs-on: ubuntu-latest + strategy: + fail-fast: false + max-parallel: 3 + matrix: + config: + - name: "BF16 weight sharding" + func: "recompute" + args: "--matmul-dtype bf16 --shard-weights" + - name: "FP8 + memcpy" + func: "recompute" + args: "--matmul-dtype e4m3 --shard-weights --memcpy-all-gather" + - name: "FP8 + persistent quants" + func: "recompute" + args: "--matmul-dtype e4m3 --shard-weights --persistent-quants --offload-quants" + - name: "Fixed BF16" + func: "fixed" + args: "bf16" + - name: "Fixed FP8" + func: "fixed" + args: "e4m3" + - name: "Fixed BF16 Shard gradient" + func: "fixed" + args: "bf16 --shard-gradients" + - name: "Fixed FP8 
Shard gradient" + func: "fixed" + args: "e4m3 --shard-gradients" + steps: + - name: Checkout code + uses: actions/checkout@v4 + + # Note: No need to download wheel again, it's already in the deployed image + + - name: Install Modal + run: pip install modal + + - name: Set Modal token + run: modal token set --token-id ${{ secrets.MODAL_TOKEN_ID }} --token-secret ${{ secrets.MODAL_TOKEN_SECRET }} + + - name: Run test on Modal + run: python3 scripts/modal_test_ci.py ${{ matrix.config.func }} ${{ matrix.config.args }} --gpus 2 + release: if: github.event_name == 'workflow_dispatch' || startsWith(github.ref, 'refs/heads/release-') || startsWith(github.ref, 'refs/tags/') needs: - - test-modal-recompute - - test-modal-fixed + - test-modal-multi-gpu runs-on: ubuntu-latest permissions: diff --git a/scripts/modal_test_app.py b/scripts/modal_test_app.py index 3feaa36..5e8a58d 100644 --- a/scripts/modal_test_app.py +++ b/scripts/modal_test_app.py @@ -5,6 +5,7 @@ Usage: modal run modal_test_app.py [-- test args...] 
""" +import argparse import io import sys from pathlib import Path @@ -68,13 +69,7 @@ def compare_and_create_report(result, expected): } -@app.function( - gpu="L4", - memory=8192, - timeout=300, - image=image, -) -def run_recompute_test(test_args: list[str]): +def _run_recompute_test(test_args: list[str]): """Run recomputation tests on Modal.""" from pyllmq.tests.run import parse_args, run_training from pyllmq.tests.recompute import disable_recompute @@ -89,6 +84,26 @@ def run_recompute_test(test_args: list[str]): return compare_and_create_report(test_run, ref_run) +@app.function( + gpu="L4", + memory=8192, + timeout=300, + image=image, +) +def run_recompute_test(test_args: list[str]): + return _run_recompute_test(test_args) + + +@app.function( + gpu="L4:2", + memory=8192, + timeout=300, + image=image, +) +def run_recompute_test_x2(test_args: list[str]): + return _run_recompute_test(test_args) + + def run_with_config(test_args: list[str]): from pyllmq.tests.run import parse_args, run_training config = parse_args(test_args) @@ -144,6 +159,56 @@ def run_fixed_result_test(dtype: str = "bf16"): return report +@app.function( + gpu="L4:2", + memory=8192, + timeout=300, + image=image, +) +def run_fixed_result_test_x2(dtype: str = "bf16", shard_gradients: bool = False): + from pyllmq.tests.run import RunResult + + print(f"Launching Modal fixed_result test with dtype: {dtype}") + + if dtype == "e5m2": + args = [f"--matmul-dtype=e4m3", "--gradient-dtype=e5m2"] + else: + args = [f"--matmul-dtype={dtype}"] + + if shard_gradients: + args += ["--shard-gradients"] + + args += ["--gpus=2"] + + """Run tests on Modal.""" + result = run_with_config(args) + if dtype == "bf16": + expected = { + "losses": [3.4119365215301514, 3.394049882888794, 3.4545254707336426, 3.0694894790649414, 3.007321834564209, 3.3855042457580566, 3.368359088897705, 3.421376943588257, 3.1316380500793457, 3.2092301845550537, 3.01995849609375], + "norms": [5.42860746383667, 5.231578826904297, 5.656546115875244, 
4.69525146484375, 4.644282341003418, 5.210570812225342, 5.396310806274414, 4.417316913604736, 4.4374165534973145, 4.28884220123291], + } + elif dtype == "e4m3": + expected = { + "losses": [3.4303817749023438, 3.43670392036438, 3.483766555786133, 3.0972299575805664, 3.0326924324035645, 3.409470558166504, 3.3872318267822266, 3.4421865940093994, 3.152552843093872, 3.229149341583252, 3.0453014373779297], + "norms": [5.8067474365234375, 8.371203422546387, 5.1532464027404785, 4.662567615509033, 4.763641834259033, 4.693724632263184, 5.259921073913574, 4.645272731781006, 4.207671165466309, 4.346331596374512] + } + elif dtype == "e5m2": + expected = { + "losses": [3.4303817749023438, 3.4341166019439697, 3.4837355613708496, 3.09706711769104, 3.0316996574401855, 3.410259962081909, 3.3873462677001953, 3.441790819168091, 3.1511523723602295, 3.2284598350524902, 3.0418832302093506], + "norms": [5.7736382484436035, 8.317730903625488, 5.149673938751221, 4.641636371612549, 4.685691833496094, 4.650301933288574, 5.228470325469971, 4.605687618255615, 4.183129787445068, 4.276437759399414], + } + else: + raise ValueError(f"Unknown dtype: {dtype}") + + report = compare_and_create_report(result, RunResult(**expected)) + if not report["passed"]: + import json + import dataclasses + # this helps with debugging/updating in case of failure + print(json.dumps(dataclasses.asdict(result))) + return report + + @app.function( gpu="L4", memory=8192, @@ -166,10 +231,21 @@ def run_torch_compare_step(test_args: list): } +def _get_gpu_arg(args: tuple[str, ...]) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--gpus", type=int, default="1") + parsed_args, _ = parser.parse_known_args(args) + return parsed_args.gpus + + @app.local_entrypoint() def test_recompute(*test_args: str): print(f"Launching Modal recomputation test with args: {test_args}") - result = run_recompute_test.remote(list(test_args)) + gpus = _get_gpu_arg(test_args) + if gpus == 2: + result = 
run_recompute_test_x2.remote(list(test_args)) + else: + result = run_recompute_test.remote(list(test_args)) # Print the comparison report print("\n" + result["report"]) @@ -188,15 +264,16 @@ def test_torch_step(*test_args: str): @app.local_entrypoint() -def test_fixed(dtype: str = "bf16"): +def test_fixed(dtype: str = "bf16", gpus: int = 1, shard_gradients: bool = False): print(f"Launching Modal test with dtype: {dtype}") - result = run_fixed_result_test.remote(dtype) + if gpus == 2: + result = run_fixed_result_test_x2.remote(dtype, shard_gradients) + else: + assert not shard_gradients, "shard_gradient only supported for 2 gpus" + result = run_fixed_result_test.remote(dtype) # Print the comparison report print("\n" + result["report"]) if not result["passed"]: sys.exit(1) - - - diff --git a/scripts/modal_test_ci.py b/scripts/modal_test_ci.py index 3d56380..c9afa69 100644 --- a/scripts/modal_test_ci.py +++ b/scripts/modal_test_ci.py @@ -5,33 +5,56 @@ Usage: python run_modal_tests.py [test args...] 
""" +import argparse import sys import modal +def _get_gpu_arg(args: list[str]) -> tuple[int, list[str]]: + parser = argparse.ArgumentParser() + parser.add_argument("--gpus", type=int, default="1") + parsed_args, rest = parser.parse_known_args(args) + return parsed_args.gpus, rest + + if __name__ == "__main__": # Reference the already-deployed app app = modal.App.lookup("llmq-test", create_if_missing=False) test_name = sys.argv[1] + gpus, rest = _get_gpu_arg(sys.argv[2:]) + test_args_pos = [] + test_args_kw = {} if test_name == "recompute": # Get the run_recompute_test function from the deployed app - test_fn = modal.Function.from_name("llmq-test", "run_recompute_test") - test_args = sys.argv[2:] + if gpus == 2: + test_fn = modal.Function.from_name("llmq-test", "run_recompute_test_x2") + else: + test_fn = modal.Function.from_name("llmq-test", "run_recompute_test") + test_args_pos = [sys.argv[2:]] elif test_name == "fixed": - test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test") - test_args = sys.argv[2] + parser = argparse.ArgumentParser() + parser.add_argument("dtype", type=str) + parser.add_argument("--shard-gradients", action="store_true") + parsed_args, rest = parser.parse_known_args(rest) + if gpus == 2: + test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test_x2") + test_args_kw = {"dtype": parsed_args.dtype, "shard_gradients": parsed_args.shard_gradients} + else: + assert not parsed_args.shard_gradients, "shard_gradient only supported for 2 gpus" + test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test") + test_args_kw = {"dtype": parsed_args.dtype} elif test_name == "torch-step": test_fn = modal.Function.from_name("llmq-test", "run_torch_compare_step") - test_args = sys.argv[2:] + test_args_pos = [sys.argv[2:]] else: raise RuntimeError(f"Unknown test type {test_name}") # Get test arguments from command line - print(f"Launching Modal test with args: {test_args}") - result = test_fn.remote(test_args) + 
print(f"Launching Modal test with args: {test_args_pos}, {test_args_kw}") + result = test_fn.remote(*test_args_pos, **test_args_kw) # Print the comparison report print("\n" + result["report"]) diff --git a/scripts/train.py b/scripts/train.py index dc369b7..569fb77 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -83,74 +83,7 @@ def run_evaluation(trainer: pyllmq.LLMQTrainer, eval_loader: pyllmq.DataLoader, def parse_args(): parser = argparse.ArgumentParser(description="Train LLaMa model") - default = pyllmq.TrainingConfig() - - # Model configuration - parser.add_argument("--model", default=default.model, help="Path to model directory or HuggingFace model name") - parser.add_argument("--from-scratch", action="store_true", help="Train from random initialization") - parser.add_argument("--init-proj-to-zero", action="store_true", help="Initialize projections to zero") - parser.add_argument("--model-dtype", default=default.model_dtype, help="Model dtype") - parser.add_argument("--matmul-dtype", help="Matmul dtype (defaults to model-dtype)") - parser.add_argument("--gradient-dtype", help="Gradient dtype (defaults to matmul-dtype, except e4m3 matmul uses m5m2 gradients)") - - # Batch configuration - parser.add_argument("--batch-size", "--batch", type=int, default=default.batch_size, help="Micro-batch size") - parser.add_argument("--seq-len", "--seq-length", type=int, default=default.seq_len, help="Sequence length") - parser.add_argument("--grad-accumulation", type=int, default=default.grad_accumulation, help="Gradient accumulation steps") - parser.add_argument("--lmhead-chunks", type=int, default=default.lmhead_chunks, help="Run LM-head in smaller chunks") - parser.add_argument("--attn-bwd-chunks", type=int, default=default.attn_bwd_chunks, help="Run attention backward in smaller chunks") - - # Optimizer - parser.add_argument("--learning-rate", "--lr", type=float, default=default.learning_rate, help="Learning rate") - parser.add_argument("--warmup", type=int, 
default=default.warmup_steps, dest="warmup_steps", help="Warmup steps") - parser.add_argument("--final-lr-fraction", type=float, default=default.final_lr_fraction, help="Final LR fraction") - parser.add_argument("--beta-1", type=float, default=default.beta_1, help="Adam beta 1") - parser.add_argument("--beta-2", type=float, default=default.beta_2, help="Adam beta 2") - parser.add_argument("--opt-m-dtype", default=default.opt_m_dtype, help="First-order momentum dtype") - parser.add_argument("--opt-v-dtype", default=default.opt_v_dtype, help="Second-order momentum dtype") - parser.add_argument("--grad-clip", type=float, default=default.grad_clip, help="Gradient clipping") - parser.add_argument("--weight-decay", type=float, default=default.weight_decay, help="Weight decay") - - # Training - parser.add_argument("--steps", type=int, default=default.steps, help="Training steps") - parser.add_argument("--eval-every-n-steps", type=int, default=default.eval_every, dest="eval_every", help="Evaluation interval") - parser.add_argument("--eval-num-steps", type=int, default=default.eval_num_steps, help="Number of eval batches") - parser.add_argument("--log-gpu-util", type=int, default=default.log_gpu_util, help="GPU logging interval (0 to disable)") - - # Data - parser.add_argument("--train-file", default=default.train_file, help="Training data file") - parser.add_argument("--eval-file", default=default.eval_file, help="Evaluation data file") - - # Output - parser.add_argument("--out-dir", default=default.out_dir, help="Output directory") - parser.add_argument("--checkpoint-dir", default=default.checkpoint_dir, help="Checkpoint directory") - parser.add_argument("--log-file", default=default.log_file, help="Log file") - parser.add_argument("--ckpt-interval", type=int, default=default.ckpt_interval, help="Checkpoint interval") - parser.add_argument("--ckpt-keep-n", type=int, default=default.ckpt_keep_n, help="Number of checkpoints to keep") - parser.add_argument("--ckpt-major", 
type=int, default=default.ckpt_major, help="Major checkpoint interval") - parser.add_argument("--continue", dest="continue_from_checkpoint", action="store_true", - help="Continue from checkpoint") - - # Multi-GPU - parser.add_argument("--gpus", type=int, default=pyllmq.get_num_gpus(), help="Number of GPUs") - - # Memory optimization - parser.add_argument("--recompute-swiglu", action="store_true", help="Recompute SwiGLU") - parser.add_argument("--recompute-norm", action="store_true", help="Recompute RMSNorm") - parser.add_argument("--recompute-ffn", action="store_true", help="Recompute FFN") - parser.add_argument("--recompute-qkv", action="store_true", help="Recompute QKV") - parser.add_argument("--recompute-att", action="store_true", help="Recompute attention") - parser.add_argument("--recompute-block", action="store_true", help="Recompute entire block") - - # Distributed training - parser.add_argument("--zero-level", type=int, default=1, help="ZeRO optimization level (1-3)") - parser.add_argument("--shard-weights", action="store_true", help="Shard weights across GPUs") - parser.add_argument("--shard-gradients", action="store_true", help="Shard gradients across GPUs") - parser.add_argument("--offload-master", action="store_true", help="Offload master weights to CPU") - parser.add_argument("--offload-quants", action="store_true", help="Offload quantized weights") - parser.add_argument("--offload-opt-m", action="store_true", help="Offload first-order momentum") - parser.add_argument("--offload-opt-v", action="store_true", help="Offload second-order momentum") - parser.add_argument("--persistent-quants", action="store_true", help="Keep quantized weights") + pyllmq.add_training_args(parser) def add_toggle(arg: str, default: bool, help: str): dest = arg.replace("-", "_") @@ -164,19 +97,6 @@ def add_toggle(arg: str, default: bool, help: str): add_toggle("all-to-all-reduce", True, "Use custom all-to-all reduce which can be used with memcpy-send-recv") 
add_toggle("write-combined", False, "Use write-combined memory. May give faster PCIe transfers.") - # Logging - parser.add_argument("-qq", "--silent", dest="verbosity", action="store_const", const=pyllmq.LogVerbosity.SILENT, - help="Silent mode (no output)") - parser.add_argument("-q", "--quiet", dest="verbosity", action="store_const", const=pyllmq.LogVerbosity.QUIET, - help="Quiet mode (minimal output)") - parser.add_argument("-v", "--verbose", dest="verbosity", action="store_const", const=pyllmq.LogVerbosity.VERBOSE, - help="Verbose mode (detailed output)") - parser.set_defaults(verbosity=default.verbosity) - parser.add_argument("--use-wandb", action="store_true", help="Enable Weights & Biases logging") - parser.add_argument("--wandb-project", default=default.wandb_project, help="W&B project name (defaults to 'LLMQ')") - parser.add_argument("--wandb-name", default=default.wandb_name, help="W&B run name") - - args = parser.parse_args() return pyllmq.TrainingConfig(**vars(args)) diff --git a/src/binding/py_train.cpp b/src/binding/py_train.cpp index 9036eab..05174ca 100644 --- a/src/binding/py_train.cpp +++ b/src/binding/py_train.cpp @@ -22,10 +22,15 @@ MultiGPUPyTrainer::MultiGPUPyTrainer(int ngpus, LLamaConfig config, LLamaOptions mThreads = NCCLCommunicator::launch_threads_communicators( ngpus, memcpy_all_gather, memcpy_send_recv, [&](NCCLCommunicator& comm) { - this->main_loop(comm); + try { + this->main_loop(comm); + } catch (...) 
{ + mHasCrashed = true; + throw; + } }); - while(!mIsRunning) { + while(!mIsRunning && !mHasCrashed) { std::this_thread::yield(); } } @@ -207,12 +212,13 @@ void MultiGPUPyTrainer::main_loop(NCCLCommunicator& comm) { ctx.Model = std::make_unique(mConfig, mOptions, comm.rank(), comm.world_size()); ctx.Model->allocate_run_state(mOptions, comm, B, T); - if(mIsReady.fetch_add(1) == comm.world_size() - 1) { + if (mIsReady.fetch_add(1) == comm.world_size() - 1) { mIsRunning = true; }; while (!mIsRunning.load()) { std::this_thread::yield(); + if(mHasCrashed.load()) throw std::runtime_error("Another worker has crashed, exiting."); } while (mIsRunning.load()) { @@ -225,7 +231,7 @@ void MultiGPUPyTrainer::main_loop(NCCLCommunicator& comm) { } CUDA_CHECK(cudaDeviceSynchronize()); comm.barrier(); - + // free resources ctx.Model.reset(); ctx.GPUUtil.reset(); diff --git a/src/binding/py_train.h b/src/binding/py_train.h index 1a8b71b..9ed6e52 100644 --- a/src/binding/py_train.h +++ b/src/binding/py_train.h @@ -85,6 +85,7 @@ class MultiGPUPyTrainer std::vector mContexts; std::mutex mGlobalMutex; std::atomic mIsRunning = false; + std::atomic mHasCrashed = false; std::atomic mIsReady = 0; std::atomic mWorkDone = 0; diff --git a/src/binding/python/__init__.py b/src/binding/python/__init__.py index 3d64d71..acd0bba 100644 --- a/src/binding/python/__init__.py +++ b/src/binding/python/__init__.py @@ -1,2 +1,2 @@ from ._pyllmq import * -from .training import TrainingConfig, CosineLRSchedule, training_logger_context +from .training import TrainingConfig, CosineLRSchedule, training_logger_context, add_training_args diff --git a/src/binding/python/tests/recompute.py b/src/binding/python/tests/recompute.py index 39078c7..e136bb0 100644 --- a/src/binding/python/tests/recompute.py +++ b/src/binding/python/tests/recompute.py @@ -17,10 +17,11 @@ import copy import sys -from pyllmq.tests.run import RunConfig, run_training, parse_args, compare_results +from pyllmq.training import TrainingConfig 
+from pyllmq.tests.run import run_training, parse_args, compare_results -def disable_recompute(config: RunConfig): +def disable_recompute(config: TrainingConfig): baseline_config = copy.deepcopy(config) baseline_config.recompute_swiglu = False - baseline_config.recompute_rms_norm = False + baseline_config.recompute_norm = False @@ -30,9 +31,13 @@ def disable_recompute(config: RunConfig): baseline_config.recompute_block = False baseline_config.use_cuda_graphs = False baseline_config.offload_master = False + baseline_config.offload_quants = False baseline_config.offload_opt_v = False baseline_config.offload_opt_m = False baseline_config.attn_bwd_chunks = 1 + baseline_config.memcpy_all_gather = False + baseline_config.shard_weights = False + baseline_config.persistent_quants = False return baseline_config diff --git a/src/binding/python/tests/run.py b/src/binding/python/tests/run.py index 2a1e1d6..9476f27 100644 --- a/src/binding/python/tests/run.py +++ b/src/binding/python/tests/run.py @@ -4,60 +4,7 @@ from typing import Optional, List import pyllmq import numpy as np - - -@dataclass -class RunConfig: - """Configuration for recomputation testing.""" - # Training hyperparameters - batch_size: int = 2 - seq_len: int = 1024 - grad_accum: int = 4 - lmhead_chunks: int = 1 - attn_bwd_chunks: int = 1 - max_steps: int = 10 - - # Optimizer settings - beta_1: float = 0.9 - beta_2: float = 0.95 - grad_clip: float = 1.0 - weight_decay: float = 0.1 - learning_rate: float = 1e-5 - - # Model settings - model_dtype: str = "bf16" - model: str = "Qwen/Qwen2.5-0.5B" - train_file: str = "data/tiny-shakespeare-qwen/train.bin" - matmul_dtype: Optional[str] = None - gradient_dtype: Optional[str] = None - - # Communication settings - memcpy_all_gather: bool = False - memcpy_send_recv: bool = False - - # Optimizer dtypes - opt_m_dtype: str = "fp32" - opt_v_dtype: str = "fp32" - - # Recomputation options - recompute_swiglu: bool = False - recompute_rms_norm: bool = False - recompute_ffn: bool = False - recompute_qkv: bool = False - 
recompute_att: bool = False - recompute_block: bool = False - offload_residual: bool = False - use_cuda_graphs: bool = False - - # offloading - offload_master: bool = False - offload_quants: bool = False - offload_opt_m: bool = False - offload_opt_v: bool = False - persistent_quants: bool = False - - # Test settings - seed: int = 0x83b45442 +from pyllmq.training import TrainingConfig @dataclass @@ -106,11 +53,11 @@ def compare_results(result: RunResult, expected: RunResult, *, file=None, atol=0 return passed -def _create_options(config: RunConfig) -> pyllmq.LLamaOptions: +def _create_options(config: TrainingConfig) -> pyllmq.LLamaOptions: """Create LLamaOptions from config.""" options = pyllmq.LLamaOptions() options.recompute_swiglu = config.recompute_swiglu - options.recompute_rms_norm = config.recompute_rms_norm + options.recompute_rms_norm = config.recompute_norm options.recompute_ffn = config.recompute_ffn options.recompute_qkv = config.recompute_qkv options.recompute_att = config.recompute_att @@ -142,7 +89,7 @@ def _create_options(config: RunConfig) -> pyllmq.LLamaOptions: return options -def run_training(config: RunConfig) -> RunResult: +def run_training(config: TrainingConfig) -> RunResult: """Run training with given options and return losses and norms.""" options = _create_options(config) @@ -150,12 +97,12 @@ def run_training(config: RunConfig) -> RunResult: # Create trainer trainer = pyllmq.LLMQTrainer.from_pretrained( name=config.model, - ngpu=1, + ngpu=config.gpus, dtype=config.model_dtype, options=options, batch_size=config.batch_size, seq_len=config.seq_len, - grad_accum=config.grad_accum, + grad_accum=config.grad_accumulation, memcpy_all_gather=config.memcpy_all_gather, memcpy_send_recv=config.memcpy_send_recv ) @@ -163,21 +110,21 @@ def run_training(config: RunConfig) -> RunResult: # Create data loader train_loader = pyllmq.DataLoader( [config.train_file], - config.batch_size * config.seq_len, - seed=config.seed + config.batch_size * config.seq_len 
* config.gpus, + seed=0x83b45442 ) # Prepare input/output buffers - in_tokens = np.empty((config.batch_size, config.seq_len), dtype=np.int32) - out_tokens = np.empty((config.batch_size, config.seq_len), dtype=np.int32) + in_tokens = np.empty((config.batch_size * config.gpus, config.seq_len), dtype=np.int32) + out_tokens = np.empty((config.batch_size * config.gpus, config.seq_len), dtype=np.int32) losses = [] norms = [] # Training loop - for step in range(config.max_steps): + for step in range(config.steps): # Gradient accumulation loop - for j in range(config.grad_accum): + for j in range(config.grad_accumulation): train_loader.load_batch(in_tokens, out_tokens) trainer.step(in_tokens, out_tokens) @@ -201,100 +148,24 @@ def run_training(config: RunConfig) -> RunResult: return RunResult(losses=losses, norms=norms) -def parse_args(args: list = None) -> RunConfig: +def parse_args(args: list = None) -> TrainingConfig: """Parse command line arguments and return TestConfig.""" parser = argparse.ArgumentParser( description="Test recomputation strategies produce identical results" ) - - parser.add_argument("--model", default=RunConfig.model, - help="Path to HuggingFace model directory or cached model name") - parser.add_argument("--matmul-dtype", default=None, - help="Which dtype to use for matmuls (defaults to model-dtype)") - parser.add_argument("--gradient-dtype", default=None, - help="Which dtype to use for activation gradients (defaults to matmul-dtype)") - parser.add_argument("--model-dtype", default=RunConfig.model_dtype, - help="Which dtype to use for model") - parser.add_argument("--train-file", default=RunConfig.train_file, - help="Tokens for training") - parser.add_argument("--grad-accumulation", type=int, default=RunConfig.grad_accum, - help="Number of micro-batches per optimizer step") - parser.add_argument("--lmhead-chunks", type=int, default=RunConfig.lmhead_chunks, - help="Number of chunks for the lm-head") - parser.add_argument("--attn-bwd-chunks", 
type=int, default=RunConfig.attn_bwd_chunks, - help="Number of chunks for attention backward") - - # Recomputation options - parser.add_argument("--recompute-swiglu", action="store_true", - help="Recompute swiglu during backward pass") - parser.add_argument("--recompute-norm", action="store_true", - help="Recompute rms-norms during backward pass") - parser.add_argument("--recompute-ffn", action="store_true", - help="Recompute feed-forward block during backward pass") - parser.add_argument("--recompute-qkv", action="store_true", - help="Recompute qkv projections during backward pass") - parser.add_argument("--recompute-att", action="store_true", - help="Recompute attention block during backward pass") - parser.add_argument("--recompute-block", action="store_true", - help="Recompute entire transformer block") - parser.add_argument("--offload-residual", action="store_true", - help="Offload residual activations") - parser.add_argument("--use-cuda-graphs", action="store_true", - help="Enable CUDA graphs") - - parser.add_argument("--offload-master", action="store_true", help="Offload master weights to CPU") - parser.add_argument("--offload-quants", action="store_true", help="Offload quantized weights") - parser.add_argument("--offload-opt-m", action="store_true", help="Offload first-order momentum") - parser.add_argument("--offload-opt-v", action="store_true", help="Offload second-order momentum") - parser.add_argument("--persistent-quants", action="store_true", help="Keep quantized weights") - - # Optional parameters - parser.add_argument("--batch-size", "--batch", type=int, default=RunConfig.batch_size, - help="Micro-batch size") - parser.add_argument("--seq-len", "--seq-length", type=int, default=RunConfig.seq_len, - help="Sequence length") - parser.add_argument("--steps", type=int, default=RunConfig.max_steps, - help="Number of training steps") - parser.add_argument("--beta-1", type=float, default=RunConfig.beta_1, - help="Beta 1 for Adam") - 
parser.add_argument("--beta-2", type=float, default=RunConfig.beta_2, - help="Beta 2 for Adam") - parser.add_argument("--opt-m-dtype", default=RunConfig.opt_m_dtype, - help="DType for first-order momentum") - parser.add_argument("--opt-v-dtype", default=RunConfig.opt_v_dtype, - help="DType for second-order momentum") - parser.add_argument("--grad-clip", type=float, default=RunConfig.grad_clip, - help="Gradient clipping") - parser.add_argument("--weight-decay", type=float, default=RunConfig.weight_decay, - help="Weight decay for matrix parameters") - parser.add_argument("--learning-rate", "--lr", type=float, default=RunConfig.learning_rate, - help="Learning rate") - + from pyllmq.training import add_training_args + default = TrainingConfig() + default.steps = 10 + default.train_file = "data/tiny-shakespeare-qwen/train.bin" + default.batch_size = 2 + add_training_args(parser, default=default) + parser.add_argument("--use-cuda-graphs", action="store_true") + parser.add_argument("--memcpy-all-gather", action="store_true") + parser.add_argument("--memcpy-send-recv", action="store_true") + parser.add_argument("--all-to-all-reduce", action="store_true") + parser.add_argument("--write-combined", action="store_true") args = parser.parse_args(args=args) - return RunConfig( - model=args.model, - model_dtype=args.model_dtype, - train_file=args.train_file, - grad_accum=args.grad_accumulation, - batch_size=args.batch_size, - seq_len=args.seq_len, - max_steps=args.steps, - beta_1=args.beta_1, - beta_2=args.beta_2, - grad_clip=args.grad_clip, - weight_decay=args.weight_decay, - learning_rate=args.learning_rate, - matmul_dtype=args.matmul_dtype, - gradient_dtype=args.gradient_dtype, - opt_m_dtype=args.opt_m_dtype, - opt_v_dtype=args.opt_v_dtype, - recompute_swiglu=args.recompute_swiglu, - recompute_rms_norm=args.recompute_norm, - recompute_ffn=args.recompute_ffn, - recompute_qkv=args.recompute_qkv, - recompute_att=args.recompute_att, - recompute_block=args.recompute_block, - 
offload_residual=args.offload_residual, - use_cuda_graphs=args.use_cuda_graphs, - ) + cfg = TrainingConfig(**vars(args)) + cfg.eval_file = cfg.train_file + return cfg diff --git a/src/binding/python/tests/torch_reference.py b/src/binding/python/tests/torch_reference.py index 328ff3a..01bd00a 100644 --- a/src/binding/python/tests/torch_reference.py +++ b/src/binding/python/tests/torch_reference.py @@ -9,16 +9,17 @@ import transformers import pyllmq -from pyllmq.tests.run import RunConfig, parse_args, _create_options +from pyllmq.tests.run import parse_args, _create_options +from pyllmq.training import TrainingConfig -def torch_grad_one_step(config: RunConfig): +def torch_grad_one_step(config: TrainingConfig): torch_model = transformers.AutoModelForCausalLM.from_pretrained(config.model, device_map="cuda", torch_dtype=torch.float32) data_loader = pyllmq.DataLoader( [config.train_file], config.batch_size * config.seq_len, - seed=config.seed + seed=0x83b45442 ) in_tokens = np.empty((config.batch_size, config.seq_len), dtype=np.int32) @@ -27,14 +28,14 @@ def torch_grad_one_step(config: RunConfig): result = {} torch_model.zero_grad() - for j in range(config.grad_accum): + for j in range(config.grad_accumulation): data_loader.load_batch(in_tokens, out_tokens) logits = torch_model(torch.tensor(in_tokens).to("cuda")).logits loss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits[0].shape[-1]).to(torch.float32), torch.tensor(out_tokens).to("cuda").to(torch.int64).reshape(-1), reduction="none") loss = loss.reshape(out_tokens.shape) - loss = loss.sum() / (out_tokens.shape[0] * out_tokens.shape[1] * config.grad_accum) + loss = loss.sum() / (out_tokens.shape[0] * out_tokens.shape[1] * config.grad_accumulation) loss.backward() for k, v in torch_model.named_parameters(): @@ -43,18 +44,18 @@ def torch_grad_one_step(config: RunConfig): return result -def llmq_grad_one_step(config: RunConfig): +def llmq_grad_one_step(config: TrainingConfig): options = 
_create_options(config) # Create trainer trainer = pyllmq.LLMQTrainer.from_pretrained( name=config.model, - ngpu=1, + ngpu=config.gpus, dtype=config.model_dtype, options=options, batch_size=config.batch_size, seq_len=config.seq_len, - grad_accum=config.grad_accum, + grad_accum=config.grad_accumulation, memcpy_all_gather=config.memcpy_all_gather, memcpy_send_recv=config.memcpy_send_recv ) @@ -63,21 +64,21 @@ def llmq_grad_one_step(config: RunConfig): train_loader = pyllmq.DataLoader( [config.train_file], config.batch_size * config.seq_len, - seed=config.seed + seed=0x83b45442 ) # Prepare input/output buffers in_tokens = np.empty((config.batch_size, config.seq_len), dtype=np.int32) out_tokens = np.empty((config.batch_size, config.seq_len), dtype=np.int32) - for j in range(config.grad_accum): + for j in range(config.grad_accumulation): train_loader.load_batch(in_tokens, out_tokens) trainer.step(in_tokens, out_tokens) return {k: torch.from_dlpack(v).cpu().to(torch.float32).numpy() for k, v in trainer.get_gradients(0).items()} def compare_single_step(config, file=None): - config.max_steps = 1 + config.steps = 1 torch_grads = torch_grad_one_step(config) torch.cuda.empty_cache() diff --git a/src/binding/python/training.py b/src/binding/python/training.py index 6574c64..d3a235e 100644 --- a/src/binding/python/training.py +++ b/src/binding/python/training.py @@ -1,3 +1,4 @@ +import argparse import contextlib import sys from dataclasses import dataclass, asdict @@ -69,6 +70,7 @@ class TrainingConfig: recompute_qkv: bool = False recompute_att: bool = False recompute_block: bool = False + offload_residual: bool = False # Distributed training options zero_level: int = 1 @@ -86,6 +88,7 @@ class TrainingConfig: memcpy_send_recv: bool = False all_to_all_reduce: bool = False write_combined: bool = False + use_zero_copy: bool = False # Logging verbosity verbosity: str = _pyllmq.LogVerbosity.DEFAULT @@ -96,6 +99,91 @@ class TrainingConfig: wandb_name: str = "llmq" +def 
add_training_args(parser: argparse.ArgumentParser, default: Optional[TrainingConfig] = None): + default = TrainingConfig() if default is None else default + + # Model configuration + parser.add_argument("--model", default=default.model, help="Path to model directory or HuggingFace model name") + parser.add_argument("--from-scratch", action="store_true", help="Train from random initialization") + parser.add_argument("--init-proj-to-zero", action="store_true", help="Initialize projections to zero") + parser.add_argument("--model-dtype", default=default.model_dtype, help="Model dtype") + parser.add_argument("--matmul-dtype", help="Matmul dtype (defaults to model-dtype)") + parser.add_argument("--gradient-dtype", help="Gradient dtype (defaults to matmul-dtype, except e4m3 matmul uses m5m2 gradients)") + + # Batch configuration + parser.add_argument("--batch-size", "--batch", type=int, default=default.batch_size, help="Micro-batch size") + parser.add_argument("--seq-len", "--seq-length", type=int, default=default.seq_len, help="Sequence length") + parser.add_argument("--grad-accumulation", type=int, default=default.grad_accumulation, help="Gradient accumulation steps") + parser.add_argument("--lmhead-chunks", type=int, default=default.lmhead_chunks, help="Run LM-head in smaller chunks") + parser.add_argument("--attn-bwd-chunks", type=int, default=default.attn_bwd_chunks, help="Run attention backward in smaller chunks") + + # Optimizer + parser.add_argument("--learning-rate", "--lr", type=float, default=default.learning_rate, help="Learning rate") + parser.add_argument("--warmup", type=int, default=default.warmup_steps, dest="warmup_steps", help="Warmup steps") + parser.add_argument("--final-lr-fraction", type=float, default=default.final_lr_fraction, help="Final LR fraction") + parser.add_argument("--beta-1", type=float, default=default.beta_1, help="Adam beta 1") + parser.add_argument("--beta-2", type=float, default=default.beta_2, help="Adam beta 2") + 
parser.add_argument("--opt-m-dtype", default=default.opt_m_dtype, help="First-order momentum dtype") + parser.add_argument("--opt-v-dtype", default=default.opt_v_dtype, help="Second-order momentum dtype") + parser.add_argument("--grad-clip", type=float, default=default.grad_clip, help="Gradient clipping") + parser.add_argument("--weight-decay", type=float, default=default.weight_decay, help="Weight decay") + + # Training + parser.add_argument("--steps", type=int, default=default.steps, help="Training steps") + parser.add_argument("--eval-every-n-steps", type=int, default=default.eval_every, dest="eval_every", help="Evaluation interval") + parser.add_argument("--eval-num-steps", type=int, default=default.eval_num_steps, help="Number of eval batches") + parser.add_argument("--log-gpu-util", type=int, default=default.log_gpu_util, help="GPU logging interval (0 to disable)") + + # Data + parser.add_argument("--train-file", default=default.train_file, help="Training data file") + parser.add_argument("--eval-file", default=default.eval_file, help="Evaluation data file") + + # Output + parser.add_argument("--out-dir", default=default.out_dir, help="Output directory") + parser.add_argument("--checkpoint-dir", default=default.checkpoint_dir, help="Checkpoint directory") + parser.add_argument("--log-file", default=default.log_file, help="Log file") + parser.add_argument("--ckpt-interval", type=int, default=default.ckpt_interval, help="Checkpoint interval") + parser.add_argument("--ckpt-keep-n", type=int, default=default.ckpt_keep_n, help="Number of checkpoints to keep") + parser.add_argument("--ckpt-major", type=int, default=default.ckpt_major, help="Major checkpoint interval") + parser.add_argument("--continue", dest="continue_from_checkpoint", action="store_true", + help="Continue from checkpoint") + + # Multi-GPU + parser.add_argument("--gpus", type=int, default=_pyllmq.get_num_gpus(), help="Number of GPUs") + + # Memory optimization + 
parser.add_argument("--recompute-swiglu", action="store_true", help="Recompute SwiGLU") + parser.add_argument("--recompute-norm", action="store_true", help="Recompute RMSNorm") + parser.add_argument("--recompute-ffn", action="store_true", help="Recompute FFN") + parser.add_argument("--recompute-qkv", action="store_true", help="Recompute QKV") + parser.add_argument("--recompute-att", action="store_true", help="Recompute attention") + parser.add_argument("--recompute-block", action="store_true", help="Recompute entire block") + parser.add_argument("--offload-residual", action="store_true", help="Offload residual activations") + + # Distributed training + parser.add_argument("--zero-level", type=int, default=1, help="ZeRO optimization level (1-3)") + parser.add_argument("--shard-weights", action="store_true", help="Shard weights across GPUs") + parser.add_argument("--shard-gradients", action="store_true", help="Shard gradients across GPUs") + parser.add_argument("--offload-master", action="store_true", help="Offload master weights to CPU") + parser.add_argument("--offload-quants", action="store_true", help="Offload quantized weights") + parser.add_argument("--offload-opt-m", action="store_true", help="Offload first-order momentum") + parser.add_argument("--offload-opt-v", action="store_true", help="Offload second-order momentum") + parser.add_argument("--persistent-quants", action="store_true", help="Keep quantized weights") + parser.add_argument("--use-zero-copy", action="store_true", help="Use zero-copy DMA for offloaded optimizer states") + + # Logging + parser.add_argument("-qq", "--silent", dest="verbosity", action="store_const", const=_pyllmq.LogVerbosity.SILENT, + help="Silent mode (no output)") + parser.add_argument("-q", "--quiet", dest="verbosity", action="store_const", const=_pyllmq.LogVerbosity.QUIET, + help="Quiet mode (minimal output)") + parser.add_argument("-v", "--verbose", dest="verbosity", action="store_const", const=_pyllmq.LogVerbosity.VERBOSE, + 
help="Verbose mode (detailed output)") + parser.set_defaults(verbosity=default.verbosity) + parser.add_argument("--use-wandb", action="store_true", help="Enable Weights & Biases logging") + parser.add_argument("--wandb-project", default=default.wandb_project, help="W&B project name (defaults to 'LLMQ')") + parser.add_argument("--wandb-name", default=default.wandb_name, help="W&B run name") + + class CosineLRSchedule: """Cosine learning rate schedule with linear warmup.""" def __init__(self, base_lr: float, max_steps: int, warmup_steps: int, final_lr: float): diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index 4e32f6e..3735d0c 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -389,6 +389,9 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, Grads->notify_lnf_w(main_stream, comm); rs->fetch_res_ffn(L-2, comm.stream()); Parameters->gather_block(L - 1, comm, *rs); + + cudaEvent_t test_event; + CUDA_CHECK(cudaEventCreate(&test_event)); // now backward all the layers for (int l = L-1; l >= 0; l--) { NvtxRange layer_range("Layer", l); @@ -407,9 +410,11 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, Tensor residual = l == 0 ? 
rs->Encoded : rs->get_res_ffn(l - 1, main_stream); trace_or_execute_cuda_graph([&]() { _recompute_block(weights, rs->Acts[l], residual); + if(last_step) { + CUDA_CHECK(cudaStreamWaitEvent(main_stream, test_event, 0)); + } _backward_block(accumulate, weights, dw, rs->Acts[l], rs->DActs[l]); }, main_stream, rs->BackwardBlockGraph, rs->Options.UseCudaGraphs); - if(l > 0) { auto& prev_dacts = rs->DActs.at(l - 1); rmsnorm_backward(prev_dacts.DResFFN.Value, dw.LN1_w, rs->RMSNormScratch, prev_dacts.DResAtt.Value, d_acts.DLN1, @@ -422,6 +427,7 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, } Parameters->release_block(l, main_stream); Grads->notify_block(l, main_stream, comm); + CUDA_CHECK(cudaEventRecord(test_event, comm.stream())); } auto& d_emb = Grads->get_embeddings_full(main_stream, comm, accumulate);