diff --git a/week06_fsdp/practice.ipynb b/week06_fsdp/practice.ipynb new file mode 100644 index 0000000..f8ecc09 --- /dev/null +++ b/week06_fsdp/practice.ipynb @@ -0,0 +1,495 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-02-09T15:28:03.083391Z", + "iopub.status.busy": "2025-02-09T15:28:03.082960Z", + "iopub.status.idle": "2025-02-09T15:28:12.294769Z", + "shell.execute_reply": "2025-02-09T15:28:12.293484Z", + "shell.execute_reply.started": "2025-02-09T15:28:03.083356Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "sudo rm -rf torchtitan\n", + "git clone -q https://github.com/pytorch/torchtitan\n", + "git -C torchtitan checkout -q 49c6d6fc15ef644e5c3b1003ad4e0d9ea5fcb9a9\n", + "curl -s https://gist.githubusercontent.com/antony-frolov/c2e69bbda2b4418b1ab1c99839c55877/raw/c873709f6fe34dbf8ba678302e4fa92d6ed8c7f1/1b.patch -o 1b.patch\n", + "patch -s -p1 -i ../1b.patch -d torchtitan\n", + "sudo pip install -q fire triton -r ./torchtitan/requirements.txt ./torchtitan\n", + "sudo apt-get update -qq && sudo apt-get install -qq pciutils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-02-09T16:08:10.559856Z", + "iopub.status.busy": "2025-02-09T16:08:10.559432Z", + "iopub.status.idle": "2025-02-09T16:08:10.571441Z", + "shell.execute_reply": "2025-02-09T16:08:10.570076Z", + "shell.execute_reply.started": "2025-02-09T16:08:10.559821Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "%%writefile train.py\n", + "import functools\n", + "import os\n", + "import pickle\n", + "import time\n", + "from typing import Optional\n", + "\n", + "import fire\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.distributed import DeviceMesh, init_device_mesh\n", + "from torch.distributed._composable.fsdp import (\n", + " CPUOffloadPolicy,\n", + " MixedPrecisionPolicy,\n", + " fully_shard,\n", + ")\n", + "from torch.optim.lr_scheduler import LambdaLR\n", + "\n", + "import torchtitan.utils as utils\n", + "from torchtitan.datasets import build_hf_data_loader, build_tokenizer\n", + "from torchtitan.logging import init_logger, logger\n", + "from torchtitan.metrics import build_device_memory_monitor\n", + "from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config\n", + "from torchtitan.optimizer import linear_warmup_linear_decay\n", + "\n", + "\n", + "def trace_handler(prof, trace_dir: str):\n", + " curr_trace_dir_name = \"iteration_\" + str(prof.step_num)\n", + " curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)\n", + " if not os.path.exists(curr_trace_dir):\n", + " os.makedirs(curr_trace_dir, exist_ok=True)\n", + "\n", + " logger.info(f\"Dumping profiler traces at step {prof.step_num}\")\n", + " begin = time.monotonic()\n", + " prof.export_chrome_trace(\n", + " f\"{curr_trace_dir}/rank{torch.distributed.get_rank()}_trace.json\"\n", + " )\n", + " logger.info(\n", + " f\"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds\"\n", + " )\n", + "\n", + "\n", + "class MemoryProfiler:\n", + " def __init__(\n", + " self,\n", + " step_num: int,\n", + " freq: int,\n", + " snapshot_dir: str,\n", + " dir_name: Optional[str] = None,\n", + " ):\n", + " self.snapshot_dir = snapshot_dir\n", + " if not os.path.exists(snapshot_dir):\n", + " os.makedirs(snapshot_dir, exist_ok=True)\n", + "\n", + " # when resume training, we start from the last 
step\n", + " self.step_num = step_num\n", + " self.freq = freq\n", + "\n", + " self.dir_name = dir_name\n", + "\n", + " def step(self):\n", + " self.step_num += 1\n", + " if self.step_num % self.freq not in [0, self.freq - 1]:\n", + " return\n", + " if self.step_num % self.freq == self.freq - 1:\n", + " torch.cuda.memory._record_memory_history()\n", + " return\n", + " curr_step = self.step_num\n", + " if self.dir_name is None:\n", + " dir_name = f\"iteration_{curr_step}\"\n", + " else:\n", + " dir_name = self.dir_name\n", + " curr_snapshot_dir = os.path.join(self.snapshot_dir, dir_name)\n", + " if not os.path.exists(curr_snapshot_dir):\n", + " os.makedirs(curr_snapshot_dir, exist_ok=True)\n", + " logger.info(f\"Dumping memory snapshot at step {curr_step}\")\n", + " begin = time.monotonic()\n", + " with open(\n", + " f\"{curr_snapshot_dir}/rank{torch.distributed.get_rank()}_memory_snapshot.pickle\",\n", + " \"wb\",\n", + " ) as output:\n", + " pickle.dump(torch.cuda.memory._snapshot(), output)\n", + " torch.cuda.memory._record_memory_history(None)\n", + " logger.info(\n", + " f\"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds\"\n", + " )\n", + "\n", + "\n", + "def apply_fsdp(\n", + " model: nn.Module,\n", + " dp_mesh: DeviceMesh,\n", + " param_dtype: torch.dtype,\n", + " reduce_dtype: torch.dtype,\n", + " cpu_offload: bool,\n", + " reshard_after_forward: bool,\n", + "):\n", + " mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)\n", + " fsdp_config = {\"mesh\": dp_mesh, \"mp_policy\": mp_policy}\n", + " if cpu_offload:\n", + " fsdp_config[\"offload_policy\"] = CPUOffloadPolicy()\n", + "\n", + " for layer_id, transformer_block in model.layers.items():\n", + " fully_shard(\n", + " transformer_block,\n", + " **fsdp_config,\n", + " reshard_after_forward=reshard_after_forward,\n", + " )\n", + " fully_shard(model, **fsdp_config)\n", + "\n", + "\n", + "def train(\n", + " lr: float = 8e-4,\n", + " max_norm: float = 1.0,\n", + " training_steps: int = 10,\n", + " warmup_steps: int = 2,\n", + " batch_size: int = 8,\n", + " seq_len: int = 2048,\n", + " model_name: str = \"llama3\",\n", + " flavor: str = \"debugmodel\",\n", + " norm_type: str = \"rmsnorm\",\n", + " enable_cpu_offload: bool = False,\n", + " param_dtype: str = \"float32\",\n", + " reduce_dtype: str = \"float32\",\n", + " reshard_after_forward: bool = True,\n", + " reshard_after_forward_degree: int | None = None,\n", + " device_type: str = \"cuda\",\n", + " log_freq: int = 1,\n", + " gc_freq: int = 50,\n", + " profile_freq: int = 10,\n", + " profile_active: int = 1,\n", + " profile_warmup: int = 3,\n", + " dump_folder: str = \".\",\n", + " save_traces_folder: str = \"profile_trace\",\n", + " save_memory_snapshot_folder: str = \"memory_snapshot\",\n", + " apply_compile: bool = False,\n", + " num_gas_steps: int = 1,\n", + " reshard_after_backward: bool = True,\n", + " reduce_grads: bool = True,\n", + "):\n", + " decay_steps = training_steps - warmup_steps\n", + " param_dtype = getattr(torch, param_dtype)\n", + " reduce_dtype = getattr(torch, reduce_dtype)\n", + " if reshard_after_forward_degree is not None:\n", + " assert reshard_after_forward\n", + " reshard_after_forward = reshard_after_forward_degree\n", + "\n", + " init_logger()\n", + "\n", + " # take control of garbage collection to avoid stragglers\n", + " gc_handler = utils.GarbageCollection(gc_freq=gc_freq)\n", + "\n", + " # init distributed\n", + " world_size = int(os.environ[\"WORLD_SIZE\"])\n", + " device = 
torch.device(f\"{device_type}:{int(os.environ['LOCAL_RANK'])}\")\n", + " torch.cuda.set_device(device)\n", + " if not torch.distributed.is_initialized():\n", + " torch.distributed.init_process_group(\"cuda:nccl,cpu:gloo\")\n", + " # initialize device memory monitor and get peak flops for MFU calculation\n", + " device_memory_monitor = build_device_memory_monitor()\n", + " gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)\n", + " logger.info(f\"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}\")\n", + "\n", + " # build meshes\n", + " world_mesh = init_device_mesh(device_type, (world_size,), mesh_dim_names=(\"dp\",))\n", + " dp_mesh = world_mesh[\"dp\"]\n", + " dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()\n", + "\n", + " # build tokenizer\n", + " tokenizer_type = model_name_to_tokenizer[model_name]\n", + " tokenizer = build_tokenizer(\n", + " tokenizer_type, \"torchtitan/tests/assets/test_tiktoken.model\"\n", + " )\n", + " # build dataloader\n", + " data_loader = build_hf_data_loader(\n", + " \"c4_test\",\n", + " \"torchtitan/tests/assets/c4_test\",\n", + " tokenizer,\n", + " batch_size=batch_size,\n", + " seq_len=seq_len,\n", + " world_size=dp_degree,\n", + " rank=dp_rank,\n", + " )\n", + "\n", + " # build model (using meta init)\n", + " model_cls = model_name_to_cls[model_name]\n", + " model_config = models_config[model_name][flavor]\n", + " model_config.norm_type = norm_type\n", + " model_config.vocab_size = tokenizer.n_words\n", + " model_config.max_seq_len = seq_len\n", + "\n", + " logger.info(f\"Building {model_name} {flavor} with {model_config}\")\n", + " memory_profiler = MemoryProfiler(\n", + " profile_freq - 2,\n", + " profile_freq,\n", + " snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n", + " dir_name=\"model_init\",\n", + " )\n", + " memory_profiler.step()\n", + " with torch.device(\"meta\"):\n", + " model = model_cls.from_model_args(model_config)\n", + "\n", + " # log model size\n", + " model_param_count = utils.get_num_params(model)\n", + " num_flop_per_token = utils.get_num_flop_per_token(\n", + " utils.get_num_params(model, exclude_embedding=True),\n", + " model_config,\n", + " seq_len,\n", + " )\n", + " logger.info(\n", + " f\"Model {model_name} {flavor} \" f\"size: {model_param_count:,} total parameters\"\n", + " )\n", + "\n", + " # loss function\n", + " def loss_fn(pred, labels):\n", + " return torch.nn.functional.cross_entropy(\n", + " pred.flatten(0, 1).float(), labels.flatten(0, 1)\n", + " )\n", + "\n", + " # move sharded model to CPU/GPU and initialize weights via DTensor\n", + " if enable_cpu_offload:\n", + " init_device = \"cpu\"\n", + " buffer_device = device_type\n", + " else:\n", + " init_device = device_type\n", + " buffer_device = None\n", + "\n", + " # apply parallelisms and initialization\n", + " if apply_compile:\n", + " for layer_id, transformer_block in model.layers.named_children():\n", + " transformer_block = torch.compile(transformer_block, fullgraph=True)\n", + " model.layers.register_module(layer_id, transformer_block)\n", + " logger.info(\"Compiling each TransformerBlock with torch.compile\")\n", + " apply_fsdp(\n", + " model,\n", + " dp_mesh=dp_mesh,\n", + " param_dtype=param_dtype,\n", + " reduce_dtype=reduce_dtype,\n", + " cpu_offload=enable_cpu_offload,\n", + " reshard_after_forward=reshard_after_forward,\n", + " )\n", + " model.to_empty(device=init_device)\n", + " with torch.no_grad():\n", + " model.init_weights(buffer_device=buffer_device)\n", + " model.train()\n", + "\n", + " 
memory_profiler.step()\n", + "\n", + " device_mem_stats = device_memory_monitor.get_peak_stats()\n", + " logger.info(\n", + " f\"{device_type.upper()} memory usage for model: \"\n", + " f\"{device_mem_stats.max_reserved_gib:.2f}GiB\"\n", + " f\"({device_mem_stats.max_reserved_pct:.2f}%)\"\n", + " )\n", + "\n", + " optimizer = torch.optim.AdamW(\n", + " model.parameters(),\n", + " lr=lr,\n", + " betas=(0.9, 0.95),\n", + " weight_decay=0.1,\n", + " fused=True,\n", + " )\n", + " lr_scheduler = LambdaLR(\n", + " optimizer,\n", + " lr_lambda=functools.partial(\n", + " linear_warmup_linear_decay, warmup_steps, decay_steps\n", + " ),\n", + " )\n", + "\n", + " data_iterator = iter(data_loader)\n", + "\n", + " train_context = utils.get_train_context(\n", + " enable_loss_parallel=False,\n", + " enable_compiled_autograd=False,\n", + " )\n", + "\n", + " # variables used to keep info for metrics logging\n", + " step = 0\n", + " ntokens_since_last_log = 0\n", + " data_loading_times = []\n", + " time_last_log = time.perf_counter()\n", + " device_memory_monitor.reset_peak_stats()\n", + "\n", + " # train loop\n", + " logger.info(\n", + " f\"Training starts at step {step + 1}, \"\n", + " f\"with local batch size {batch_size}, \"\n", + " f\"global batch size {batch_size * dp_degree}, \"\n", + " f\"sequence length {seq_len}, \"\n", + " f\"total steps {training_steps} \"\n", + " f\"(warmup {warmup_steps})\"\n", + " )\n", + " with torch.profiler.profile(\n", + " activities=[\n", + " torch.profiler.ProfilerActivity.CPU,\n", + " torch.profiler.ProfilerActivity.CUDA,\n", + " ],\n", + " schedule=torch.profiler.schedule(\n", + " wait=profile_freq - (profile_active + profile_warmup),\n", + " warmup=profile_warmup,\n", + " active=profile_active,\n", + " ),\n", + " on_trace_ready=functools.partial(\n", + " trace_handler, trace_dir=os.path.join(dump_folder, save_traces_folder)\n", + " ),\n", + " record_shapes=True,\n", + " ) as torch_profiler:\n", + " while step < training_steps:\n", + " memory_profiler = MemoryProfiler(\n", + " step,\n", + " profile_freq,\n", + " snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n", + " )\n", + "\n", + " step += 1\n", + " gc_handler.run(step)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " for gas_step in range(num_gas_steps):\n", + " is_last_backward = gas_step == num_gas_steps - 1\n", + " model.set_is_last_backward(is_last_backward)\n", + " model.set_reshard_after_backward(\n", + " reshard_after_backward or is_last_backward\n", + " )\n", + " model.set_requires_gradient_sync(reduce_grads or is_last_backward)\n", + "\n", + " # get batch\n", + " data_load_start = time.perf_counter()\n", + " batch = next(data_iterator)\n", + " input_ids, labels = batch\n", + " ntokens_since_last_log += labels.numel()\n", + " data_loading_times.append(time.perf_counter() - data_load_start)\n", + "\n", + " input_ids = input_ids.to(device_type)\n", + " labels = labels.to(device_type)\n", + "\n", + " # Non-PP forward / backward\n", + " with train_context():\n", + " pred = model(input_ids)\n", + " loss = loss_fn(pred, labels)\n", + " # pred.shape=(bs, seq_len, vocab_size)\n", + " # need to free to before bwd to avoid peaking memory\n", + " del pred\n", + " loss.backward()\n", + "\n", + " # clip gradients\n", + " torch.nn.utils.clip_grad_norm_([p for p in model.parameters()], max_norm)\n", + "\n", + " # optimizer step\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + "\n", + " # log metrics\n", + " if step == 1 or step % log_freq == 0:\n", + " loss = loss.detach()\n", + " 
global_avg_loss = utils.dist_mean(loss, dp_mesh)\n", + "\n", + " time_delta = time.perf_counter() - time_last_log\n", + "\n", + " # tokens per second per device, abbreviated as tps\n", + " tps = ntokens_since_last_log / time_delta\n", + " # model FLOPS utilization\n", + " # For its definition and calculation, please refer to the PaLM paper:\n", + " # https://arxiv.org/abs/2204.02311\n", + " mfu = 100 * num_flop_per_token * tps / gpu_peak_flops\n", + "\n", + " device_mem_stats = device_memory_monitor.get_peak_stats()\n", + "\n", + " logger.info(\n", + " f\"step: {step:2} \"\n", + " f\"loss: {global_avg_loss:7.4f} \"\n", + " f\"memory: {device_mem_stats.max_reserved_gib:5.2f}GiB\"\n", + " f\"({device_mem_stats.max_reserved_pct:.2f}%) \"\n", + " f\"tps: {round(tps):,} \"\n", + " f\"mfu: {mfu:.2f}%\"\n", + " )\n", + "\n", + " ntokens_since_last_log = 0\n", + " data_loading_times.clear()\n", + " time_last_log = time.perf_counter()\n", + " device_memory_monitor.reset_peak_stats()\n", + "\n", + " # signal the profiler that the next profiling step has started\n", + " if torch_profiler:\n", + " torch_profiler.step()\n", + " if memory_profiler:\n", + " memory_profiler.step()\n", + "\n", + " logger.info(\"Training completed\")\n", + "\n", + " torch.distributed.destroy_process_group()\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " fire.Fire(train)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!OMP_NUM_THREADS=1 \\\n", + " torchrun \\\n", + " --local-ranks-filter 0 \\\n", + " --nproc-per-node 2 \\\n", + " train.py \\\n", + " --flavor 1B \\\n", + " --batch-size 2 \\\n", + " --seq-len 1024 \\\n", + " --training-steps 20 \\\n", + " --warmup-steps 5 \\\n", + " --gc-freq 5 \\\n", + " --profile-freq 10 \\\n", + " \\\n", + " --param-dtype float16 \\\n", + " --reduce-dtype float16" + ] + } + ], + "metadata": { + "kaggle": { + "accelerator": "none", + "dataSources": [], + "dockerImageVersionId": 30887, + "isGpuEnabled": false, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/week06_fsdp/slides/.gitignore b/week06_fsdp/slides/.gitignore new file mode 100644 index 0000000..ee89780 --- /dev/null +++ b/week06_fsdp/slides/.gitignore @@ -0,0 +1,2 @@ +node_modules +pnpm-lock.yaml diff --git a/week06_fsdp/slides/assets/1b_mem_snap.png b/week06_fsdp/slides/assets/1b_mem_snap.png new file mode 100644 index 0000000..aad630e Binary files /dev/null and b/week06_fsdp/slides/assets/1b_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png new file mode 100644 index 0000000..6811078 Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png 
b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png new file mode 100644 index 0000000..03c59ec Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png differ diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png new file mode 100644 index 0000000..40a8044 Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png new file mode 100644 index 0000000..8074c1e Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png differ diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png new file mode 100644 index 0000000..af476cd Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png new file mode 100644 index 0000000..9d402dd Binary files /dev/null and b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png new file mode 100644 index 0000000..bb086b4 Binary files /dev/null and b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png differ diff --git a/week06_fsdp/slides/assets/8b_compile_backward_trace.png b/week06_fsdp/slides/assets/8b_compile_backward_trace.png new file mode 100644 index 0000000..eea5903 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_backward_trace.png differ diff --git a/week06_fsdp/slides/assets/8b_compile_forward_trace.png b/week06_fsdp/slides/assets/8b_compile_forward_trace.png new file mode 100644 index 0000000..7ee9178 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_forward_trace.png differ diff --git a/week06_fsdp/slides/assets/8b_compile_trace.png b/week06_fsdp/slides/assets/8b_compile_trace.png new file mode 100644 index 0000000..a152cf2 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_trace.png differ diff --git a/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png b/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png new file mode 100644 index 0000000..dd34750 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png b/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png new file mode 100644 index 0000000..943d70b Binary files /dev/null and b/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png b/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png new file mode 100644 index 0000000..03e0455 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png differ diff --git a/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png 
b/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png new file mode 100644 index 0000000..956a461 Binary files /dev/null and b/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png differ diff --git a/week06_fsdp/slides/assets/dcp_1.png b/week06_fsdp/slides/assets/dcp_1.png new file mode 100644 index 0000000..b9f741d Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_1.png differ diff --git a/week06_fsdp/slides/assets/dcp_2.png b/week06_fsdp/slides/assets/dcp_2.png new file mode 100644 index 0000000..154828b Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_2.png differ diff --git a/week06_fsdp/slides/assets/dcp_3.png b/week06_fsdp/slides/assets/dcp_3.png new file mode 100644 index 0000000..fb57a81 Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_3.png differ diff --git a/week06_fsdp/slides/assets/dcp_saving_flow.png b/week06_fsdp/slides/assets/dcp_saving_flow.png new file mode 100644 index 0000000..30dd209 Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_saving_flow.png differ diff --git a/week06_fsdp/slides/assets/device_mesh.png b/week06_fsdp/slides/assets/device_mesh.png new file mode 100644 index 0000000..2ccabcc Binary files /dev/null and b/week06_fsdp/slides/assets/device_mesh.png differ diff --git a/week06_fsdp/slides/assets/dtensor_1.png b/week06_fsdp/slides/assets/dtensor_1.png new file mode 100644 index 0000000..e20d514 Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_1.png differ diff --git a/week06_fsdp/slides/assets/dtensor_2.png b/week06_fsdp/slides/assets/dtensor_2.png new file mode 100644 index 0000000..0813b8a Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_2.png differ diff --git a/week06_fsdp/slides/assets/dtensor_3.png b/week06_fsdp/slides/assets/dtensor_3.png new file mode 100644 index 0000000..87f34c5 Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_3.png differ diff --git a/week06_fsdp/slides/assets/forward_hook.png b/week06_fsdp/slides/assets/forward_hook.png new file mode 100644 index 0000000..8704b82 Binary files /dev/null and b/week06_fsdp/slides/assets/forward_hook.png differ diff --git a/week06_fsdp/slides/assets/forward_pre_hook.png b/week06_fsdp/slides/assets/forward_pre_hook.png new file mode 100644 index 0000000..c177b28 Binary files /dev/null and b/week06_fsdp/slides/assets/forward_pre_hook.png differ diff --git a/week06_fsdp/slides/assets/fsdp_workflow.png b/week06_fsdp/slides/assets/fsdp_workflow.png new file mode 100644 index 0000000..1a8df0e Binary files /dev/null and b/week06_fsdp/slides/assets/fsdp_workflow.png differ diff --git a/week06_fsdp/slides/assets/fsdp_wrap.png b/week06_fsdp/slides/assets/fsdp_wrap.png new file mode 100644 index 0000000..1c1d05a Binary files /dev/null and b/week06_fsdp/slides/assets/fsdp_wrap.png differ diff --git a/week06_fsdp/slides/assets/streams.png b/week06_fsdp/slides/assets/streams.png new file mode 100644 index 0000000..f49a04d Binary files /dev/null and b/week06_fsdp/slides/assets/streams.png differ diff --git a/week06_fsdp/slides/package.json b/week06_fsdp/slides/package.json new file mode 100644 index 0000000..4f6912c --- /dev/null +++ b/week06_fsdp/slides/package.json @@ -0,0 +1,7 @@ +{ + "dependencies": { + "@slidev/cli": "^51.3.0", + "@slidev/theme-default": "^0.25.0", + "playwright-chromium": "^1.50.1" + } +} diff --git a/week06_fsdp/slides/slides.md b/week06_fsdp/slides/slides.md new file mode 100644 index 0000000..cd9cb8d --- /dev/null +++ b/week06_fsdp/slides/slides.md @@ -0,0 +1,605 @@ +--- +title: FSDP 
Seminar +transition: slide-left +--- + +# FSDP Seminar + +--- + +# Plan +
+
+
+ +- Prerequisites: CUDA streams / events, DeviceMesh, DTensor +- FSDP2: interface, options, internals +- PyTorch DCP, efficient garbage collection + +--- + +# CUDA streams and events + +```python +all_gather_stream = torch.cuda.Stream() + +... + +# layer 3 unshard +with torch.cuda.stream(all_gather_stream): + model.layers[3].all_gather() + all_gather_event_3 = torch.cuda.Event() + # or all_gather_stream.record_event() + +# layer 2 forward +activations = model.layers[2](activations) + +# layer 4 unshard +with torch.cuda.stream(all_gather_stream): + model.layers[4].all_gather() + all_gather_event_4 = torch.cuda.Event() + +# layer 3 forward +torch.cuda.default_stream().wait_event(all_gather_event_3) +activations = model.block[3](activations) + +... + +``` + +--- + +# CUDA streams and events + +
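- The snippet above is schematic; below is a self-contained sketch of the same overlap pattern with toy `nn.Linear` blocks (and `torch.cuda._sleep` standing in for the parameter all-gather kernel). The detail that matters is that an event must be *recorded* on the producing stream, e.g. via `stream.record_event()`, before another stream can wait on it.

```python
import torch
import torch.nn as nn

# toy stand-ins for transformer blocks
layers = nn.ModuleList(nn.Linear(1024, 1024, device="cuda") for _ in range(4))
x = torch.randn(8, 1024, device="cuda")

all_gather_stream = torch.cuda.Stream()

def fake_all_gather():
    # placeholder for the all-gather that materializes a layer's full parameters
    torch.cuda._sleep(10_000_000)

# prefetch layer 1's "unshard" on the side stream and record an event there
with torch.cuda.stream(all_gather_stream):
    fake_all_gather()
    layer1_ready = all_gather_stream.record_event()

# layer 0's forward runs on the default stream, overlapping with the all-gather
x = layers[0](x)

# before layer 1's compute, the default stream waits for the all-gather event
torch.cuda.default_stream().wait_event(layer1_ready)
x = layers[1](x)
torch.cuda.synchronize()
```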
+
+
+ + + + + +--- + +# DeviceMesh + + + +--- + +# DeviceMesh + +```python +from torch.distributed.device_mesh import init_device_mesh + +mesh_1d = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("dp",)) +mesh_2d = init_device_mesh("cpu", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) +mesh_3d = init_device_mesh( + "cuda", + mesh_shape=(2, 2, 8), + mesh_dim_names=("pp", "dp", "tp"), +) + +dp_group = mesh_2d.get_group("dp") +dist.all_gather(..., group=dp_group) + +mesh_2d.get_local_rank("tp") + +mesh_3d["dp", "tp"]._flatten("dp_tp") + +``` + +--- + +# DTensor + +```python +from torch.distributed.tensor import DTensor, distribute_tensor + +mesh = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("dp",)) +big_tensor = torch.randn(1024, 4096) +placements = (Shard(dim=0),) + +dtensor = distribute_tensor( + big_tensor, + device_mesh=mesh, + placements=placements, +) +dtensor._local_tensor +dtensor.to_local() # .shape = (512, 4096) + + +shard = ... # .shape = (512, 4096) +DTensor.from_local( + shard, + device_mesh=mesh, + placements=placements, +) # .shape = (1024, 4096) + +dtensor.redistribute(placements=(Replicate(),)) +dtensor.full_tensor() +``` + +--- + +# DTensor + + +
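- A runnable variant of the DTensor snippet above, with the `Shard` / `Replicate` placements imported explicitly; it assumes two ranks launched via `torchrun`, so the 1024-row tensor splits into 512 rows per rank.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("dp",))

big_tensor = torch.randn(1024, 4096)
dtensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=(Shard(0),))
assert dtensor.to_local().shape == (512, 4096)  # each rank holds half of dim 0
assert dtensor.shape == (1024, 4096)            # logical (global) shape

# wrap an existing local shard into a DTensor without any communication
local_shard = torch.randn(512, 4096, device="cuda")
wrapped = DTensor.from_local(local_shard, device_mesh=mesh, placements=(Shard(0),))

# all-gather into a replicated DTensor or a plain full tensor
replicated = dtensor.redistribute(placements=(Replicate(),))
full = dtensor.full_tensor()

dist.destroy_process_group()
```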
+ + +--- + +# DTensor + + + +--- + +# FSDP2 + + + +--- +layout: two-cols-header +layoutClass: gap-5 +--- + +# FSDP2 + +::left:: + +```python +from torch.distributed.fsdp import fully_shard + +mesh_2d = init_device_mesh( + "cuda", + mesh_shape=(2, 8), + mesh_dim_names=("dp", "tp"), +) +model = Model() + +for layer in model.layers: + fully_shard( + module, # (module1, module2) + mesh=dp_mesh, + reshard_after_forward=True, # ZeRO-3 + mp_policy=MixedPrecisionPolicy( + param_dtype=torch.float16, + reduce_dtype=torch.float32, + ), + offload_policy=CPUOffloadPolicy(), + ) + +fully_shard(model, ...) +``` + +::right:: + +```python +for step in ...: + for gas_step in ...: + is_last_backward = gas_step == num_gas_steps - 1 + # ZeRO-3 + model.set_reshard_after_backward(is_last_backward) + # ZeRO-2 + model.requires_gradient_sync(is_last_backward) + + loss = loss_fn(model(inputs), targets) + ... +``` + +--- + +# FSDP2 + + + +--- + +# FSDP2 — hooks + + +
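- Before looking at the hooks, a corrected, self-contained version of the wrapping snippet from the FSDP2 slide above: inside the loop it is the current `layer` that gets wrapped (not `module`), and the root module is wrapped last. `Model` is a placeholder for any transformer-style module with a `.layers` container.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mesh = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("dp",))
model = Model()  # placeholder: any module with a .layers container

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # compute / all-gather dtype
    reduce_dtype=torch.float32,   # gradient reduce-scatter dtype
)

for layer in model.layers:
    fully_shard(
        layer,                       # wrap the block itself
        mesh=mesh,
        reshard_after_forward=True,  # ZeRO-3: free full params right after forward
        mp_policy=mp_policy,
    )
# the root wrap picks up the remaining params (embeddings, final norm, output head)
fully_shard(model, mesh=mesh, reshard_after_forward=True, mp_policy=mp_policy)
```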
+ + +--- + +# FSDP2 — pre-forward + +```python +def pre_forward(module, args): + module.unshard() # in all-gather stream + module.wait_for_unshard() # sync compute (default) stream with all-gather stream + module._register_post_backward_hook(args) + return args + +def unshard(module): + with torch.cuda.stream(all_gather_stream): + module.all_gather() + module.all_gather_event = all_gather_stream.record_event() + module.set_unsharded_params() + +def wait_for_unshard(module): + torch.cuda.default_stream().wait_event(module.all_gather_event) + +def fully_shard(module, ...): + ... + module.register_forward_pre_hook(pre_forward) +``` + + + +--- + +# FSDP2 — post-forward + +```python +def post_forward(module, args, output): + module.reshard() + module._record_post_forward() + module._register_pre_backward_hook(output) + return output + +def reshard(module): + module.set_sharded_params() # and free unsharded params + +def _record_post_forward(module): + post_forward_index = len(module.comm_ctx.post_forward_order) + module.comm_ctx.post_forward_order.append(module) + module._post_forward_indices.append(post_forward_index) + +def fully_shard(module, ...): + ... + module.register_forward_hook(post_forward) +``` + +--- + +# FSDP2 — pre-backward + +```python +def pre_backward(module, *unused): + module.unshard() # no-op if prefetched + module.wait_for_unshard() + module._backward_prefetch() + +def _backward_prefetch(module): + curr_index = module._post_forward_indices.pop() + target_index = curr_index - 1 + target_module = self.comm_ctx.post_forward_order[target_index] + target_module.unshard() + +def _register_pre_backward_hook(self, output): + for t in output: + if torch.is_tensor(t) and t.requires_grad: + t.register_hook(self._pre_backward) + return output +``` + +--- + +# FSDP2 — post-backward + +```python +def post_backward(module, *unused: Any): + if module.reshard_after_backward: + module.reshard() + if module.reduce_grads: + reduce_scatter_stream.wait_stream(torch.cuda.default_stream()) + with torch.cuda.stream(reduce_scatter_stream): + module.reduce_scatter_grads() + reduce_event = reduce_scatter_stream.record_event() + +def _register_post_backward_hook(module, args): + RegisterPostBackwardFunction.apply(self, *args) + +class RegisterPostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *inputs): + ctx.module = module + return inputs + + @staticmethod + def backward(ctx, *grads): + module.post_backward() + return (None,) + grads +``` + +--- + +# FSDP2 — memory + + + +--- + +# FSDP2 — memory + + + +--- + +# Computation / communication overlap + +- Implicit prefetching + - в `pre_forward` +- Explicit prefetching + - в `pre_backward` + - можно задать руками + + ```python + module.set_modules_to_forward_prefetch(modules) + module.set_modules_to_backward_prefetch(modules) + ``` + +--- + +# Подробнее про работу со стримами + +
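- As a concrete example of the explicit-prefetch API from the overlap slide: a sketch that makes every FSDP-wrapped block prefetch its neighbour's all-gather (forward: the next block, backward: the previous one). It assumes `model.layers` holds the blocks that were passed to `fully_shard`.

```python
blocks = list(model.layers)
for i, block in enumerate(blocks):
    if i + 1 < len(blocks):
        # while block i runs its forward, start all-gathering block i + 1
        block.set_modules_to_forward_prefetch([blocks[i + 1]])
    if i > 0:
        # while block i runs its backward, start all-gathering block i - 1
        block.set_modules_to_backward_prefetch([blocks[i - 1]])
```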
+
+
+
+ + + +
+ +--- + +# Подробнее про работу со стримами — forward + +
+
+
+
+ + + +
+--- + +# Подробнее про работу со стримами — backward + +
+
+
+
+ + + +
+ +--- + +# ZeRO-2 + +
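- In these terms a ZeRO-2-like run keeps parameters gathered (they are all-gathered once and never resharded) while gradients and optimizer states stay sharded; a rough sketch, using the same options that `train.py` above exposes as `reshard_after_forward` / `reshard_after_backward`:

```python
for layer in model.layers:
    fully_shard(layer, mesh=mesh, reshard_after_forward=False)
fully_shard(model, mesh=mesh, reshard_after_forward=False)
model.set_reshard_after_backward(False)  # keep full params after backward too
```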
+
+
+
+ + + +
+ +--- + +# ZeRO-2 + + + +--- + +# ZeRO-1 + +
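- Going one step further, a ZeRO-1-like run also skips the per-microbatch gradient reduce-scatter, so full unreduced gradients accumulate locally and only the optimizer states remain sharded; roughly (cf. `reduce_grads` in `train.py`, with the same placeholders as before):

```python
for gas_step in range(num_gas_steps):
    is_last_backward = gas_step == num_gas_steps - 1
    # reduce-scatter gradients only on the last microbatch of the window
    model.set_requires_gradient_sync(is_last_backward)
    loss_fn(model(inputs), targets).backward()  # placeholders
```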
+
+
+
+ + + +
+ +--- + +# ZeRO-1 + + + +--- + +# HSDP + +```python +mesh_2d = init_device_mesh( + "cpu", + mesh_shape=(2, 8), + mesh_dim_names=("dp_replicate", "dp_shard"), +) + +fully_shard( + module, + mesh=mesh_2d, + ... +) +``` + +
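- With a 2-D mesh, `fully_shard` shards each parameter over the `dp_shard` dimension and keeps a full copy of those shards in every `dp_replicate` group; gradients are reduce-scattered inside each shard group and then all-reduced across the replicate groups. A minimal per-layer wrapping sketch (same placeholders as before):

```python
mesh_2d = init_device_mesh(
    "cuda",
    mesh_shape=(2, 8),                # 2 replica groups x 8-way sharding
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

for layer in model.layers:
    fully_shard(layer, mesh=mesh_2d)  # HSDP: replicate over dim 0, shard over dim 1
fully_shard(model, mesh=mesh_2d)
```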
+ +- логика становится заметно сложнее, показывать не буду( + +--- + +# CPU offloading + +- [ZeRO-Offload](https://arxiv.org/pdf/2101.06840) + +```python +with torch.device("cpu"): + model = Model() + +fully_shard( + module, + ... + offload_policy=CPUOffloadPolicy(), +) + +def unshard(module): + sharded_param = sharded_param.to( + device, + non_blocking=True, + ) + ... + module.all_gather() + +def post_backward(module): + new_sharded_grad = new_sharded_grad.to( + torch.device("cpu"), + non_blocking=True + ) +``` + +--- + +# CPU offloading + + + +--- + +# hpZ + +- [ZeRO++](https://arxiv.org/pdf/2306.10209) + +```python +mesh = init_device_mesh( + "cuda", + mesh_shape=(16,), + mesh_dim_names=("dp",), +) +fully_shard( + module, + mesh, + ... + reshard_after_forward=8, +) +``` + +--- + +# hpZ + + + +--- + +# hpZ + + + +--- + +# PyTorch DCP + +- два вида `state_dict` + - `SHARDED_STATE_DICT` + - `FULL_STATE_DICT` +- в FSDP2 всегда sharded, но состоит из DTensor-ов + - с помощью `.redistribute()` можно менять шардирование чекпоинта +- DCP умеет эффективно отгружать чекпоинты с минимальным оверхедом + +--- + +# PyTorch DCP + +```python +import torch.distributed.checkpoint as dcp +model = Model() +fully_shard(model) +optimizer = Optimizer(model.parameters()) + +state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() +} +dcp.state_dict_saver.save(state_dict) +dcp.state_dict_loader.load(state_dict) +``` + +
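- When a single-file (unsharded) checkpoint is needed, the sharded DTensor state dict can be materialized through the state-dict utilities; a sketch that gathers everything onto CPU and saves it from rank 0:

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)

# gather the full state dict (DTensors are all-gathered) and move it to CPU
full_sd = get_model_state_dict(
    model,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
if dist.get_rank() == 0:
    torch.save(full_sd, "model_full.pt")  # plain, rank-0-only checkpoint file
```

- Recent PyTorch versions also provide `dcp.async_save(...)`, which stages the state dict and writes it in the background so checkpointing overlaps with training.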
+ +- [truthfully i's a bit more complicated](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) + +--- + +# PyTorch DCP + + + +--- + +# PyTorch DCP + +
+
+ + + +--- + +# PyTorch DCP + +
+
+ + + +--- + +# PyTorch DCP + +
+
+ + + +--- + +# Garbage collection tuning + +```python +gc.disable() +gc.collect(1) + +... init + +for step in ...: + if step > 1 and step % _gc_freq == 0: + gc.collect(1) + + ... step +``` + +--- + +# Extras + +- [SimpleFSDP](https://arxiv.org/pdf/2411.00284) +- `unshard_in_backward` +- meta device init +- compile + +--- + +# Code + +- можно поиграться со всем этим в [ноутбуке](https://www.kaggle.com/code/antonyfrolov/practice-ipynb) +- пайплайн отладки diff --git a/week06_fsdp/slides/slides.pdf b/week06_fsdp/slides/slides.pdf new file mode 100644 index 0000000..ea77015 Binary files /dev/null and b/week06_fsdp/slides/slides.pdf differ