diff --git a/week06_fsdp/practice.ipynb b/week06_fsdp/practice.ipynb
new file mode 100644
index 0000000..f8ecc09
--- /dev/null
+++ b/week06_fsdp/practice.ipynb
@@ -0,0 +1,495 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T15:28:03.083391Z",
+ "iopub.status.busy": "2025-02-09T15:28:03.082960Z",
+ "iopub.status.idle": "2025-02-09T15:28:12.294769Z",
+ "shell.execute_reply": "2025-02-09T15:28:12.293484Z",
+ "shell.execute_reply.started": "2025-02-09T15:28:03.083356Z"
+ },
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "\n",
+ "sudo rm -rf torchtitan\n",
+ "git clone -q https://github.com/pytorch/torchtitan\n",
+ "git -C torchtitan checkout -q 49c6d6fc15ef644e5c3b1003ad4e0d9ea5fcb9a9\n",
+ "curl -s https://gist.githubusercontent.com/antony-frolov/c2e69bbda2b4418b1ab1c99839c55877/raw/c873709f6fe34dbf8ba678302e4fa92d6ed8c7f1/1b.patch -o 1b.patch\n",
+ "patch -s -p1 -i ../1b.patch -d torchtitan\n",
+ "sudo pip install -q fire triton -r ./torchtitan/requirements.txt ./torchtitan\n",
+ "sudo apt-get update -qq && sudo apt-get install -qq pciutils"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T16:08:10.559856Z",
+ "iopub.status.busy": "2025-02-09T16:08:10.559432Z",
+ "iopub.status.idle": "2025-02-09T16:08:10.571441Z",
+ "shell.execute_reply": "2025-02-09T16:08:10.570076Z",
+ "shell.execute_reply.started": "2025-02-09T16:08:10.559821Z"
+ },
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile train.py\n",
+ "import functools\n",
+ "import os\n",
+ "import pickle\n",
+ "import time\n",
+ "from typing import Optional\n",
+ "\n",
+ "import fire\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.distributed import DeviceMesh, init_device_mesh\n",
+ "from torch.distributed._composable.fsdp import (\n",
+ " CPUOffloadPolicy,\n",
+ " MixedPrecisionPolicy,\n",
+ " fully_shard,\n",
+ ")\n",
+ "from torch.optim.lr_scheduler import LambdaLR\n",
+ "\n",
+ "import torchtitan.utils as utils\n",
+ "from torchtitan.datasets import build_hf_data_loader, build_tokenizer\n",
+ "from torchtitan.logging import init_logger, logger\n",
+ "from torchtitan.metrics import build_device_memory_monitor\n",
+ "from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config\n",
+ "from torchtitan.optimizer import linear_warmup_linear_decay\n",
+ "\n",
+ "\n",
+ "def trace_handler(prof, trace_dir: str):\n",
+ " curr_trace_dir_name = \"iteration_\" + str(prof.step_num)\n",
+ " curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)\n",
+ " if not os.path.exists(curr_trace_dir):\n",
+ " os.makedirs(curr_trace_dir, exist_ok=True)\n",
+ "\n",
+ " logger.info(f\"Dumping profiler traces at step {prof.step_num}\")\n",
+ " begin = time.monotonic()\n",
+ " prof.export_chrome_trace(\n",
+ " f\"{curr_trace_dir}/rank{torch.distributed.get_rank()}_trace.json\"\n",
+ " )\n",
+ " logger.info(\n",
+ " f\"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds\"\n",
+ " )\n",
+ "\n",
+ "\n",
+ "class MemoryProfiler:\n",
+ " def __init__(\n",
+ " self,\n",
+ " step_num: int,\n",
+ " freq: int,\n",
+ " snapshot_dir: str,\n",
+ " dir_name: Optional[str] = None,\n",
+ " ):\n",
+ " self.snapshot_dir = snapshot_dir\n",
+ " if not os.path.exists(snapshot_dir):\n",
+ " os.makedirs(snapshot_dir, exist_ok=True)\n",
+ "\n",
+ " # when resume training, we start from the last step\n",
+ " self.step_num = step_num\n",
+ " self.freq = freq\n",
+ "\n",
+ " self.dir_name = dir_name\n",
+ "\n",
+ " def step(self):\n",
+ " self.step_num += 1\n",
+ " if self.step_num % self.freq not in [0, self.freq - 1]:\n",
+ " return\n",
+ " if self.step_num % self.freq == self.freq - 1:\n",
+ " torch.cuda.memory._record_memory_history()\n",
+ " return\n",
+ " curr_step = self.step_num\n",
+ " if self.dir_name is None:\n",
+ " dir_name = f\"iteration_{curr_step}\"\n",
+ " else:\n",
+ " dir_name = self.dir_name\n",
+ " curr_snapshot_dir = os.path.join(self.snapshot_dir, dir_name)\n",
+ " if not os.path.exists(curr_snapshot_dir):\n",
+ " os.makedirs(curr_snapshot_dir, exist_ok=True)\n",
+ " logger.info(f\"Dumping memory snapshot at step {curr_step}\")\n",
+ " begin = time.monotonic()\n",
+ " with open(\n",
+ " f\"{curr_snapshot_dir}/rank{torch.distributed.get_rank()}_memory_snapshot.pickle\",\n",
+ " \"wb\",\n",
+ " ) as output:\n",
+ " pickle.dump(torch.cuda.memory._snapshot(), output)\n",
+ " torch.cuda.memory._record_memory_history(None)\n",
+ " logger.info(\n",
+ " f\"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds\"\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def apply_fsdp(\n",
+ " model: nn.Module,\n",
+ " dp_mesh: DeviceMesh,\n",
+ " param_dtype: torch.dtype,\n",
+ " reduce_dtype: torch.dtype,\n",
+ " cpu_offload: bool,\n",
+ " reshard_after_forward: bool,\n",
+ "):\n",
+ " mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)\n",
+ " fsdp_config = {\"mesh\": dp_mesh, \"mp_policy\": mp_policy}\n",
+ " if cpu_offload:\n",
+ " fsdp_config[\"offload_policy\"] = CPUOffloadPolicy()\n",
+ "\n",
+ " for layer_id, transformer_block in model.layers.items():\n",
+ " fully_shard(\n",
+ " transformer_block,\n",
+ " **fsdp_config,\n",
+ " reshard_after_forward=reshard_after_forward,\n",
+ " )\n",
+ " fully_shard(model, **fsdp_config)\n",
+ "\n",
+ "\n",
+ "def train(\n",
+ " lr: float = 8e-4,\n",
+ " max_norm: float = 1.0,\n",
+ " training_steps: int = 10,\n",
+ " warmup_steps: int = 2,\n",
+ " batch_size: int = 8,\n",
+ " seq_len: int = 2048,\n",
+ " model_name: str = \"llama3\",\n",
+ " flavor: str = \"debugmodel\",\n",
+ " norm_type: str = \"rmsnorm\",\n",
+ " enable_cpu_offload: bool = False,\n",
+ " param_dtype: str = \"float32\",\n",
+ " reduce_dtype: str = \"float32\",\n",
+ " reshard_after_forward: bool = True,\n",
+ " reshard_after_forward_degree: int | None = None,\n",
+ " device_type: str = \"cuda\",\n",
+ " log_freq: int = 1,\n",
+ " gc_freq: int = 50,\n",
+ " profile_freq: int = 10,\n",
+ " profile_active: int = 1,\n",
+ " profile_warmup: int = 3,\n",
+ " dump_folder: str = \".\",\n",
+ " save_traces_folder: str = \"profile_trace\",\n",
+ " save_memory_snapshot_folder: str = \"memory_snapshot\",\n",
+ " apply_compile: bool = False,\n",
+ " num_gas_steps: int = 1,\n",
+ " reshard_after_backward: bool = True,\n",
+ " reduce_grads: bool = True,\n",
+ "):\n",
+ " decay_steps = training_steps - warmup_steps\n",
+ " param_dtype = getattr(torch, param_dtype)\n",
+ " reduce_dtype = getattr(torch, reduce_dtype)\n",
+ " if reshard_after_forward_degree is not None:\n",
+ " assert reshard_after_forward\n",
+ " reshard_after_forward = reshard_after_forward_degree\n",
+ "\n",
+ " init_logger()\n",
+ "\n",
+ " # take control of garbage collection to avoid stragglers\n",
+ " gc_handler = utils.GarbageCollection(gc_freq=gc_freq)\n",
+ "\n",
+ " # init distributed\n",
+ " world_size = int(os.environ[\"WORLD_SIZE\"])\n",
+ " device = torch.device(f\"{device_type}:{int(os.environ['LOCAL_RANK'])}\")\n",
+ " torch.cuda.set_device(device)\n",
+ " if not torch.distributed.is_initialized():\n",
+ " torch.distributed.init_process_group(\"cuda:nccl,cpu:gloo\")\n",
+ " # initialize device memory monitor and get peak flops for MFU calculation\n",
+ " device_memory_monitor = build_device_memory_monitor()\n",
+ " gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)\n",
+ " logger.info(f\"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}\")\n",
+ "\n",
+ " # build meshes\n",
+ " world_mesh = init_device_mesh(device_type, (world_size,), mesh_dim_names=(\"dp\",))\n",
+ " dp_mesh = world_mesh[\"dp\"]\n",
+ " dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()\n",
+ "\n",
+ " # build tokenizer\n",
+ " tokenizer_type = model_name_to_tokenizer[model_name]\n",
+ " tokenizer = build_tokenizer(\n",
+ " tokenizer_type, \"torchtitan/tests/assets/test_tiktoken.model\"\n",
+ " )\n",
+ " # build dataloader\n",
+ " data_loader = build_hf_data_loader(\n",
+ " \"c4_test\",\n",
+ " \"torchtitan/tests/assets/c4_test\",\n",
+ " tokenizer,\n",
+ " batch_size=batch_size,\n",
+ " seq_len=seq_len,\n",
+ " world_size=dp_degree,\n",
+ " rank=dp_rank,\n",
+ " )\n",
+ "\n",
+ " # build model (using meta init)\n",
+ " model_cls = model_name_to_cls[model_name]\n",
+ " model_config = models_config[model_name][flavor]\n",
+ " model_config.norm_type = norm_type\n",
+ " model_config.vocab_size = tokenizer.n_words\n",
+ " model_config.max_seq_len = seq_len\n",
+ "\n",
+ " logger.info(f\"Building {model_name} {flavor} with {model_config}\")\n",
+ " memory_profiler = MemoryProfiler(\n",
+ " profile_freq - 2,\n",
+ " profile_freq,\n",
+ " snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n",
+ " dir_name=\"model_init\",\n",
+ " )\n",
+ " memory_profiler.step()\n",
+ " with torch.device(\"meta\"):\n",
+ " model = model_cls.from_model_args(model_config)\n",
+ "\n",
+ " # log model size\n",
+ " model_param_count = utils.get_num_params(model)\n",
+ " num_flop_per_token = utils.get_num_flop_per_token(\n",
+ " utils.get_num_params(model, exclude_embedding=True),\n",
+ " model_config,\n",
+ " seq_len,\n",
+ " )\n",
+ " logger.info(\n",
+ " f\"Model {model_name} {flavor} \" f\"size: {model_param_count:,} total parameters\"\n",
+ " )\n",
+ "\n",
+ " # loss function\n",
+ " def loss_fn(pred, labels):\n",
+ " return torch.nn.functional.cross_entropy(\n",
+ " pred.flatten(0, 1).float(), labels.flatten(0, 1)\n",
+ " )\n",
+ "\n",
+ " # move sharded model to CPU/GPU and initialize weights via DTensor\n",
+ " if enable_cpu_offload:\n",
+ " init_device = \"cpu\"\n",
+ " buffer_device = device_type\n",
+ " else:\n",
+ " init_device = device_type\n",
+ " buffer_device = None\n",
+ "\n",
+ " # apply parallelisms and initialization\n",
+ " if apply_compile:\n",
+ " for layer_id, transformer_block in model.layers.named_children():\n",
+ " transformer_block = torch.compile(transformer_block, fullgraph=True)\n",
+ " model.layers.register_module(layer_id, transformer_block)\n",
+ " logger.info(\"Compiling each TransformerBlock with torch.compile\")\n",
+ " apply_fsdp(\n",
+ " model,\n",
+ " dp_mesh=dp_mesh,\n",
+ " param_dtype=param_dtype,\n",
+ " reduce_dtype=reduce_dtype,\n",
+ " cpu_offload=enable_cpu_offload,\n",
+ " reshard_after_forward=reshard_after_forward,\n",
+ " )\n",
+ " model.to_empty(device=init_device)\n",
+ " with torch.no_grad():\n",
+ " model.init_weights(buffer_device=buffer_device)\n",
+ " model.train()\n",
+ "\n",
+ " memory_profiler.step()\n",
+ "\n",
+ " device_mem_stats = device_memory_monitor.get_peak_stats()\n",
+ " logger.info(\n",
+ " f\"{device_type.upper()} memory usage for model: \"\n",
+ " f\"{device_mem_stats.max_reserved_gib:.2f}GiB\"\n",
+ " f\"({device_mem_stats.max_reserved_pct:.2f}%)\"\n",
+ " )\n",
+ "\n",
+ " optimizer = torch.optim.AdamW(\n",
+ " model.parameters(),\n",
+ " lr=lr,\n",
+ " betas=(0.9, 0.95),\n",
+ " weight_decay=0.1,\n",
+ " fused=True,\n",
+ " )\n",
+ " lr_scheduler = LambdaLR(\n",
+ " optimizer,\n",
+ " lr_lambda=functools.partial(\n",
+ " linear_warmup_linear_decay, warmup_steps, decay_steps\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ " data_iterator = iter(data_loader)\n",
+ "\n",
+ " train_context = utils.get_train_context(\n",
+ " enable_loss_parallel=False,\n",
+ " enable_compiled_autograd=False,\n",
+ " )\n",
+ "\n",
+ " # variables used to keep info for metrics logging\n",
+ " step = 0\n",
+ " ntokens_since_last_log = 0\n",
+ " data_loading_times = []\n",
+ " time_last_log = time.perf_counter()\n",
+ " device_memory_monitor.reset_peak_stats()\n",
+ "\n",
+ " # train loop\n",
+ " logger.info(\n",
+ " f\"Training starts at step {step + 1}, \"\n",
+ " f\"with local batch size {batch_size}, \"\n",
+ " f\"global batch size {batch_size * dp_degree}, \"\n",
+ " f\"sequence length {seq_len}, \"\n",
+ " f\"total steps {training_steps} \"\n",
+ " f\"(warmup {warmup_steps})\"\n",
+ " )\n",
+ " with torch.profiler.profile(\n",
+ " activities=[\n",
+ " torch.profiler.ProfilerActivity.CPU,\n",
+ " torch.profiler.ProfilerActivity.CUDA,\n",
+ " ],\n",
+ " schedule=torch.profiler.schedule(\n",
+ " wait=profile_freq - (profile_active + profile_warmup),\n",
+ " warmup=profile_warmup,\n",
+ " active=profile_active,\n",
+ " ),\n",
+ " on_trace_ready=functools.partial(\n",
+ " trace_handler, trace_dir=os.path.join(dump_folder, save_traces_folder)\n",
+ " ),\n",
+ " record_shapes=True,\n",
+ " ) as torch_profiler:\n",
+ " while step < training_steps:\n",
+ " memory_profiler = MemoryProfiler(\n",
+ " step,\n",
+ " profile_freq,\n",
+ " snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n",
+ " )\n",
+ "\n",
+ " step += 1\n",
+ " gc_handler.run(step)\n",
+ "\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " for gas_step in range(num_gas_steps):\n",
+ " is_last_backward = gas_step == num_gas_steps - 1\n",
+ " model.set_is_last_backward(is_last_backward)\n",
+ " model.set_reshard_after_backward(\n",
+ " reshard_after_backward or is_last_backward\n",
+ " )\n",
+ " model.set_requires_gradient_sync(reduce_grads or is_last_backward)\n",
+ "\n",
+ " # get batch\n",
+ " data_load_start = time.perf_counter()\n",
+ " batch = next(data_iterator)\n",
+ " input_ids, labels = batch\n",
+ " ntokens_since_last_log += labels.numel()\n",
+ " data_loading_times.append(time.perf_counter() - data_load_start)\n",
+ "\n",
+ " input_ids = input_ids.to(device_type)\n",
+ " labels = labels.to(device_type)\n",
+ "\n",
+ " # Non-PP forward / backward\n",
+ " with train_context():\n",
+ " pred = model(input_ids)\n",
+ " loss = loss_fn(pred, labels)\n",
+ " # pred.shape=(bs, seq_len, vocab_size)\n",
+ " # need to free to before bwd to avoid peaking memory\n",
+ " del pred\n",
+ " loss.backward()\n",
+ "\n",
+ " # clip gradients\n",
+ " torch.nn.utils.clip_grad_norm_([p for p in model.parameters()], max_norm)\n",
+ "\n",
+ " # optimizer step\n",
+ " optimizer.step()\n",
+ " lr_scheduler.step()\n",
+ "\n",
+ " # log metrics\n",
+ " if step == 1 or step % log_freq == 0:\n",
+ " loss = loss.detach()\n",
+ " global_avg_loss = utils.dist_mean(loss, dp_mesh)\n",
+ "\n",
+ " time_delta = time.perf_counter() - time_last_log\n",
+ "\n",
+ " # tokens per second per device, abbreviated as tps\n",
+ " tps = ntokens_since_last_log / time_delta\n",
+ " # model FLOPS utilization\n",
+ " # For its definition and calculation, please refer to the PaLM paper:\n",
+ " # https://arxiv.org/abs/2204.02311\n",
+ " mfu = 100 * num_flop_per_token * tps / gpu_peak_flops\n",
+ "\n",
+ " device_mem_stats = device_memory_monitor.get_peak_stats()\n",
+ "\n",
+ " logger.info(\n",
+ " f\"step: {step:2} \"\n",
+ " f\"loss: {global_avg_loss:7.4f} \"\n",
+ " f\"memory: {device_mem_stats.max_reserved_gib:5.2f}GiB\"\n",
+ " f\"({device_mem_stats.max_reserved_pct:.2f}%) \"\n",
+ " f\"tps: {round(tps):,} \"\n",
+ " f\"mfu: {mfu:.2f}%\"\n",
+ " )\n",
+ "\n",
+ " ntokens_since_last_log = 0\n",
+ " data_loading_times.clear()\n",
+ " time_last_log = time.perf_counter()\n",
+ " device_memory_monitor.reset_peak_stats()\n",
+ "\n",
+ " # signal the profiler that the next profiling step has started\n",
+ " if torch_profiler:\n",
+ " torch_profiler.step()\n",
+ " if memory_profiler:\n",
+ " memory_profiler.step()\n",
+ "\n",
+ " logger.info(\"Training completed\")\n",
+ "\n",
+ " torch.distributed.destroy_process_group()\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " fire.Fire(train)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!OMP_NUM_THREADS=1 \\\n",
+ " torchrun \\\n",
+ " --local-ranks-filter 0 \\\n",
+ " --nproc-per-node 2 \\\n",
+ " train.py \\\n",
+ " --flavor 1B \\\n",
+ " --batch-size 2 \\\n",
+ " --seq-len 1024 \\\n",
+ " --training-steps 20 \\\n",
+ " --warmup-steps 5 \\\n",
+ " --gc-freq 5 \\\n",
+ " --profile-freq 10 \\\n",
+ " \\\n",
+ " --param-dtype float16 \\\n",
+ " --reduce-dtype float16"
+ ]
+ }
+ ],
+ "metadata": {
+ "kaggle": {
+ "accelerator": "none",
+ "dataSources": [],
+ "dockerImageVersionId": 30887,
+ "isGpuEnabled": false,
+ "isInternetEnabled": true,
+ "language": "python",
+ "sourceType": "notebook"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/week06_fsdp/slides/.gitignore b/week06_fsdp/slides/.gitignore
new file mode 100644
index 0000000..ee89780
--- /dev/null
+++ b/week06_fsdp/slides/.gitignore
@@ -0,0 +1,2 @@
+node_modules
+pnpm-lock.yaml
diff --git a/week06_fsdp/slides/assets/1b_mem_snap.png b/week06_fsdp/slides/assets/1b_mem_snap.png
new file mode 100644
index 0000000..aad630e
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png
new file mode 100644
index 0000000..6811078
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png
new file mode 100644
index 0000000..03c59ec
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reduce_grads_no_reshard_after_backward_no_reshard_after_forward_trace.png differ
diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png
new file mode 100644
index 0000000..40a8044
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png
new file mode 100644
index 0000000..8074c1e
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_backward_no_reshard_after_forward_trace.png differ
diff --git a/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png b/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png
new file mode 100644
index 0000000..af476cd
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_no_reshard_after_forward_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png
new file mode 100644
index 0000000..9d402dd
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png
new file mode 100644
index 0000000..bb086b4
Binary files /dev/null and b/week06_fsdp/slides/assets/1b_reshard_after_forward_4_trace.png differ
diff --git a/week06_fsdp/slides/assets/8b_compile_backward_trace.png b/week06_fsdp/slides/assets/8b_compile_backward_trace.png
new file mode 100644
index 0000000..eea5903
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_backward_trace.png differ
diff --git a/week06_fsdp/slides/assets/8b_compile_forward_trace.png b/week06_fsdp/slides/assets/8b_compile_forward_trace.png
new file mode 100644
index 0000000..7ee9178
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_forward_trace.png differ
diff --git a/week06_fsdp/slides/assets/8b_compile_trace.png b/week06_fsdp/slides/assets/8b_compile_trace.png
new file mode 100644
index 0000000..a152cf2
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_compile_trace.png differ
diff --git a/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png b/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png
new file mode 100644
index 0000000..dd34750
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_cpu_offload_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png b/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png
new file mode 100644
index 0000000..943d70b
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_no_compile_iter_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png b/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png
new file mode 100644
index 0000000..03e0455
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_on_cuda_model_mem_snap.png differ
diff --git a/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png b/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png
new file mode 100644
index 0000000..956a461
Binary files /dev/null and b/week06_fsdp/slides/assets/8b_reshard_after_forward_4_trace.png differ
diff --git a/week06_fsdp/slides/assets/dcp_1.png b/week06_fsdp/slides/assets/dcp_1.png
new file mode 100644
index 0000000..b9f741d
Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_1.png differ
diff --git a/week06_fsdp/slides/assets/dcp_2.png b/week06_fsdp/slides/assets/dcp_2.png
new file mode 100644
index 0000000..154828b
Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_2.png differ
diff --git a/week06_fsdp/slides/assets/dcp_3.png b/week06_fsdp/slides/assets/dcp_3.png
new file mode 100644
index 0000000..fb57a81
Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_3.png differ
diff --git a/week06_fsdp/slides/assets/dcp_saving_flow.png b/week06_fsdp/slides/assets/dcp_saving_flow.png
new file mode 100644
index 0000000..30dd209
Binary files /dev/null and b/week06_fsdp/slides/assets/dcp_saving_flow.png differ
diff --git a/week06_fsdp/slides/assets/device_mesh.png b/week06_fsdp/slides/assets/device_mesh.png
new file mode 100644
index 0000000..2ccabcc
Binary files /dev/null and b/week06_fsdp/slides/assets/device_mesh.png differ
diff --git a/week06_fsdp/slides/assets/dtensor_1.png b/week06_fsdp/slides/assets/dtensor_1.png
new file mode 100644
index 0000000..e20d514
Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_1.png differ
diff --git a/week06_fsdp/slides/assets/dtensor_2.png b/week06_fsdp/slides/assets/dtensor_2.png
new file mode 100644
index 0000000..0813b8a
Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_2.png differ
diff --git a/week06_fsdp/slides/assets/dtensor_3.png b/week06_fsdp/slides/assets/dtensor_3.png
new file mode 100644
index 0000000..87f34c5
Binary files /dev/null and b/week06_fsdp/slides/assets/dtensor_3.png differ
diff --git a/week06_fsdp/slides/assets/forward_hook.png b/week06_fsdp/slides/assets/forward_hook.png
new file mode 100644
index 0000000..8704b82
Binary files /dev/null and b/week06_fsdp/slides/assets/forward_hook.png differ
diff --git a/week06_fsdp/slides/assets/forward_pre_hook.png b/week06_fsdp/slides/assets/forward_pre_hook.png
new file mode 100644
index 0000000..c177b28
Binary files /dev/null and b/week06_fsdp/slides/assets/forward_pre_hook.png differ
diff --git a/week06_fsdp/slides/assets/fsdp_workflow.png b/week06_fsdp/slides/assets/fsdp_workflow.png
new file mode 100644
index 0000000..1a8df0e
Binary files /dev/null and b/week06_fsdp/slides/assets/fsdp_workflow.png differ
diff --git a/week06_fsdp/slides/assets/fsdp_wrap.png b/week06_fsdp/slides/assets/fsdp_wrap.png
new file mode 100644
index 0000000..1c1d05a
Binary files /dev/null and b/week06_fsdp/slides/assets/fsdp_wrap.png differ
diff --git a/week06_fsdp/slides/assets/streams.png b/week06_fsdp/slides/assets/streams.png
new file mode 100644
index 0000000..f49a04d
Binary files /dev/null and b/week06_fsdp/slides/assets/streams.png differ
diff --git a/week06_fsdp/slides/package.json b/week06_fsdp/slides/package.json
new file mode 100644
index 0000000..4f6912c
--- /dev/null
+++ b/week06_fsdp/slides/package.json
@@ -0,0 +1,7 @@
+{
+ "dependencies": {
+ "@slidev/cli": "^51.3.0",
+ "@slidev/theme-default": "^0.25.0",
+ "playwright-chromium": "^1.50.1"
+ }
+}
diff --git a/week06_fsdp/slides/slides.md b/week06_fsdp/slides/slides.md
new file mode 100644
index 0000000..cd9cb8d
--- /dev/null
+++ b/week06_fsdp/slides/slides.md
@@ -0,0 +1,605 @@
+---
+title: FSDP seminar
+transition: slide-left
+---
+
+# FSDP seminar
+
+---
+
+# Plan
+
+
+
+
+
+- Prerequisites: CUDA streams / events, DeviceMesh, DTensor
+- FSDP2: interface, options, internals
+- PyTorch DCP, efficient garbage collection
+
+---
+
+# CUDA streams and events
+
+```python
+all_gather_stream = torch.cuda.Stream()
+
+...
+
+# layer 3 unshard
+with torch.cuda.stream(all_gather_stream):
+ model.layers[3].all_gather()
+    all_gather_event_3 = all_gather_stream.record_event()
+    # or: create a torch.cuda.Event() and call .record() on it
+
+# layer 2 forward
+activations = model.layers[2](activations)
+
+# layer 4 unshard
+with torch.cuda.stream(all_gather_stream):
+ model.layers[4].all_gather()
+    all_gather_event_4 = all_gather_stream.record_event()
+
+# layer 3 forward
+torch.cuda.default_stream().wait_event(all_gather_event_3)
+activations = model.layers[3](activations)
+
+...
+
+```
+
+---
+
+# CUDA streams and events
+
+
+
+
+
+
+
+
+
+---
+
+# DeviceMesh
+
+
+
+---
+
+# DeviceMesh
+
+```python
+from torch.distributed.device_mesh import init_device_mesh
+
+mesh_1d = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("dp",))
+mesh_2d = init_device_mesh("cpu", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+mesh_3d = init_device_mesh(
+ "cuda",
+ mesh_shape=(2, 2, 8),
+ mesh_dim_names=("pp", "dp", "tp"),
+)
+
+dp_group = mesh_2d.get_group("dp")
+dist.all_gather(..., group=dp_group)
+
+mesh_2d.get_local_rank("tp")
+
+mesh_3d["dp", "tp"]._flatten("dp_tp")
+
+```
+
+---
+
+# DTensor
+
+```python
+from torch.distributed.tensor import DTensor, distribute_tensor
+
+mesh = init_device_mesh("cuda", mesh_shape=(8,), mesh_dim_names=("dp",))
+big_tensor = torch.randn(1024, 4096)
+placements = (Shard(dim=0),)
+
+dtensor = distribute_tensor(
+ big_tensor,
+ device_mesh=mesh,
+ placements=placements,
+)
+dtensor._local_tensor
+dtensor.to_local() # .shape = (512, 4096)
+
+
+shard = ... # .shape = (512, 4096)
+DTensor.from_local(
+ shard,
+ device_mesh=mesh,
+ placements=placements,
+) # .shape = (1024, 4096)
+
+dtensor.redistribute(placements=(Replicate(),))
+dtensor.full_tensor()
+```
+
+---
+
+# DTensor
+
+
+
+
+
+---
+
+# DTensor
+
+
+
+---
+
+# FSDP2
+
+
+
+---
+layout: two-cols-header
+layoutClass: gap-5
+---
+
+# FSDP2
+
+::left::
+
+```python
+from torch.distributed.fsdp import fully_shard
+
+mesh_2d = init_device_mesh(
+ "cuda",
+ mesh_shape=(2, 8),
+ mesh_dim_names=("dp", "tp"),
+)
+dp_mesh = mesh_2d["dp"]
+model = Model()
+
+for layer in model.layers:
+    fully_shard(
+        layer,  # or a tuple of modules: (module1, module2)
+        mesh=dp_mesh,
+ reshard_after_forward=True, # ZeRO-3
+ mp_policy=MixedPrecisionPolicy(
+ param_dtype=torch.float16,
+ reduce_dtype=torch.float32,
+ ),
+ offload_policy=CPUOffloadPolicy(),
+ )
+
+fully_shard(model, ...)
+```
+
+::right::
+
+```python
+for step in ...:
+ for gas_step in ...:
+ is_last_backward = gas_step == num_gas_steps - 1
+ # ZeRO-3
+ model.set_reshard_after_backward(is_last_backward)
+ # ZeRO-2
+        model.set_requires_gradient_sync(is_last_backward)
+
+ loss = loss_fn(model(inputs), targets)
+ ...
+```
+
+---
+
+# FSDP2
+
+
+
+---
+
+# FSDP2 — hooks
+
+
+
+
+
+---
+
+# FSDP2 — pre-forward
+
+```python
+def pre_forward(module, args):
+ module.unshard() # in all-gather stream
+ module.wait_for_unshard() # sync compute (default) stream with all-gather stream
+ module._register_post_backward_hook(args)
+ return args
+
+def unshard(module):
+ with torch.cuda.stream(all_gather_stream):
+ module.all_gather()
+ module.all_gather_event = all_gather_stream.record_event()
+ module.set_unsharded_params()
+
+def wait_for_unshard(module):
+ torch.cuda.default_stream().wait_event(module.all_gather_event)
+
+def fully_shard(module, ...):
+ ...
+ module.register_forward_pre_hook(pre_forward)
+```
+
+
+
+---
+
+# FSDP2 — post-forward
+
+```python
+def post_forward(module, args, output):
+ module.reshard()
+ module._record_post_forward()
+ module._register_pre_backward_hook(output)
+ return output
+
+def reshard(module):
+ module.set_sharded_params() # and free unsharded params
+
+def _record_post_forward(module):
+ post_forward_index = len(module.comm_ctx.post_forward_order)
+ module.comm_ctx.post_forward_order.append(module)
+ module._post_forward_indices.append(post_forward_index)
+
+def fully_shard(module, ...):
+ ...
+ module.register_forward_hook(post_forward)
+```
+
+---
+
+# FSDP2 — pre-backward
+
+```python
+def pre_backward(module, *unused):
+ module.unshard() # no-op if prefetched
+ module.wait_for_unshard()
+ module._backward_prefetch()
+
+def _backward_prefetch(module):
+ curr_index = module._post_forward_indices.pop()
+ target_index = curr_index - 1
+    target_module = module.comm_ctx.post_forward_order[target_index]
+ target_module.unshard()
+
+def _register_pre_backward_hook(module, output):
+    for t in output:
+        if torch.is_tensor(t) and t.requires_grad:
+            t.register_hook(module.pre_backward)
+ return output
+```
+
+---
+
+# FSDP2 — post-backward
+
+```python
+def post_backward(module, *unused: Any):
+ if module.reshard_after_backward:
+ module.reshard()
+ if module.reduce_grads:
+ reduce_scatter_stream.wait_stream(torch.cuda.default_stream())
+ with torch.cuda.stream(reduce_scatter_stream):
+ module.reduce_scatter_grads()
+ reduce_event = reduce_scatter_stream.record_event()
+
+def _register_post_backward_hook(module, args):
+    RegisterPostBackwardFunction.apply(module, *args)
+
+class RegisterPostBackwardFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, *inputs):
+ ctx.module = module
+ return inputs
+
+ @staticmethod
+ def backward(ctx, *grads):
+        ctx.module.post_backward()
+ return (None,) + grads
+```
+
+---
+
+# FSDP2 — memory
+
+
+
+---
+
+# FSDP2 — memory
+
+
+
+---
+
+# Computation / communication overlap
+
+- Implicit prefetching
+  - in `pre_forward`
+- Explicit prefetching
+  - in `pre_backward`
+  - can also be set manually (see the sketch after this list)
+
+ ```python
+ module.set_modules_to_forward_prefetch(modules)
+ module.set_modules_to_backward_prefetch(modules)
+ ```
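+
+A minimal sketch of wiring explicit prefetching by hand, assuming `model.layers` is a list of blocks that already had `fully_shard` applied (how far ahead to prefetch is a tuning knob, not a fixed rule):
+
+```python
+layers = list(model.layers)
+for i, layer in enumerate(layers):
+    if i + 1 < len(layers):
+        # in forward, start all-gathering the next block while this one computes
+        layer.set_modules_to_forward_prefetch([layers[i + 1]])
+    if i > 0:
+        # in backward, all-gather the previous block ahead of time
+        layer.set_modules_to_backward_prefetch([layers[i - 1]])
+```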
+
+---
+
+# Stream usage in detail
+
+
+
+
+
+
+
+
+
+
+---
+
+# Stream usage in detail — forward
+
+
+
+
+
+
+
+
+
+---
+
+# Stream usage in detail — backward
+
+
+
+
+
+
+
+
+
+
+---
+
+# ZeRO-2
+
+
+
+
+
+
+
+
+
+
+---
+
+# ZeRO-2
+
+
+
+---
+
+# ZeRO-1
+
+
+
+
+
+
+
+
+
+
+---
+
+# ZeRO-1
+
+
+
+---
+
+# HSDP
+
+```python
+mesh_2d = init_device_mesh(
+ "cpu",
+ mesh_shape=(2, 8),
+ mesh_dim_names=("dp_replicate", "dp_shard"),
+)
+
+fully_shard(
+ module,
+ mesh=mesh_2d,
+ ...
+)
+```
+
+
+
+- the logic gets noticeably more involved, so I won't walk through it here
+
+---
+
+# CPU offloading
+
+- [ZeRO-Offload](https://arxiv.org/pdf/2101.06840)
+
+```python
+with torch.device("cpu"):
+ model = Model()
+
+fully_shard(
+ module,
+ ...
+ offload_policy=CPUOffloadPolicy(),
+)
+
+def unshard(module):
+ sharded_param = sharded_param.to(
+ device,
+ non_blocking=True,
+ )
+ ...
+ module.all_gather()
+
+def post_backward(module):
+ new_sharded_grad = new_sharded_grad.to(
+ torch.device("cpu"),
+ non_blocking=True
+ )
+```
+
+---
+
+# CPU offloading
+
+
+
+---
+
+# hpZ
+
+- [ZeRO++](https://arxiv.org/pdf/2306.10209)
+
+```python
+mesh = init_device_mesh(
+ "cuda",
+ mesh_shape=(16,),
+ mesh_dim_names=("dp",),
+)
+fully_shard(
+ module,
+    mesh=mesh,
+ ...
+ reshard_after_forward=8,
+)
+```
+
+---
+
+# hpZ
+
+
+
+---
+
+# hpZ
+
+
+
+---
+
+# PyTorch DCP
+
+- two kinds of `state_dict`
+ - `SHARDED_STATE_DICT`
+ - `FULL_STATE_DICT`
+- in FSDP2 the state dict is always sharded, but it consists of DTensors (see the sketch below)
+  - the checkpoint sharding can be changed with `.redistribute()`
+- DCP can save checkpoints efficiently, with minimal overhead
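+
+A minimal sketch of materializing a full state dict from the sharded, DTensor-based one (every rank gathers full tensors here, so in practice you would only do this when needed, e.g. for export):
+
+```python
+from torch.distributed.tensor import DTensor
+
+sharded_sd = model.state_dict()  # values are DTensors sharded over the mesh
+full_sd = {
+    name: v.full_tensor() if isinstance(v, DTensor) else v
+    for name, v in sharded_sd.items()
+}
+```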
+
+---
+
+# PyTorch DCP
+
+```python
+import torch.distributed.checkpoint as dcp
+model = Model()
+fully_shard(model)
+optimizer = Optimizer(model.parameters())
+
+state_dict = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict()
+}
+dcp.state_dict_saver.save(state_dict)
+dcp.state_dict_loader.load(state_dict)
+```
+
+
+
+- [in truth it's a bit more complicated](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py)
+
+---
+
+# PyTorch DCP
+
+
+
+---
+
+# PyTorch DCP
+
+
+
+
+
+
+---
+
+# PyTorch DCP
+
+
+
+
+
+
+---
+
+# PyTorch DCP
+
+
+
+
+
+
+---
+
+# Garbage collection tuning
+
+```python
+gc.disable()   # take manual control of garbage collection
+gc.collect(1)  # collect generations 0 and 1 once after startup
+
+... init
+
+for step in ...:
+    # collect at the same step on every rank to avoid stragglers
+    if step > 1 and step % _gc_freq == 0:
+        gc.collect(1)
+
+ ... step
+```
+
+---
+
+# Extras
+
+- [SimpleFSDP](https://arxiv.org/pdf/2411.00284)
+- `unshard_in_backward`
+- meta device init (see the sketch below)
+- compile
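+
+A sketch of the meta-device init pattern used in the practice notebook (`Model`, `init_weights` and `dp_mesh` stand in for the torchtitan Llama model and the data-parallel mesh):
+
+```python
+with torch.device("meta"):
+    model = Model()              # parameters are created without allocating memory
+
+for block in model.layers:
+    fully_shard(block, mesh=dp_mesh, reshard_after_forward=True)
+fully_shard(model, mesh=dp_mesh)
+
+model.to_empty(device="cuda")    # allocate only the local shards
+with torch.no_grad():
+    model.init_weights()         # initialize weights via DTensor, shard by shard
+```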
+
+---
+
+# Code
+
+- you can play with all of this in the [notebook](https://www.kaggle.com/code/antonyfrolov/practice-ipynb)
+- the debugging pipeline
diff --git a/week06_fsdp/slides/slides.pdf b/week06_fsdp/slides/slides.pdf
new file mode 100644
index 0000000..ea77015
Binary files /dev/null and b/week06_fsdp/slides/slides.pdf differ