From 14736b862c29de0b86325f0ff83100e3a1821fef Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 22 Apr 2025 06:35:48 -0700 Subject: [PATCH 1/7] compile optimizer [ghstack-poisoned] --- recipes/full_finetune_distributed.py | 10 +++++++++- torchtune/training/__init__.py | 3 ++- torchtune/training/_compile.py | 6 ++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index df96c1ff8d..2c94ae5f2d 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -923,7 +923,15 @@ def train(self) -> None: # If sharded, collect the DTensor here if isinstance(grad_norm, DTensor): grad_norm = grad_norm.full_tensor() - self._optimizer.step() + optimizer_step_fn = self._optimizer.step + if self._compile: + def _fn(): + self._optimizer.step() + optimizer_step_fn = training.compile_optimizer_step( + _fn, + verbose=self._is_rank_zero + ) + optimizer_step_fn() self._optimizer.zero_grad(set_to_none=True) # Update the number of steps when the weights are updated diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index b2c327c617..e19d4faf3b 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -8,7 +8,7 @@ NoOpManager, OffloadActivations, ) -from torchtune.training._compile import compile_loss, compile_model +from torchtune.training._compile import compile_loss, compile_model, compile_optimizer_step from torchtune.training._distributed import ( gather_cpu_state_dict, get_distributed_backend, @@ -135,6 +135,7 @@ "setup_torch_profiler", "compile_loss", "compile_model", + "compile_optimizer_step", "NoOpManager", "OffloadActivations", "FormattedCheckpointFiles", diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py index ca02a3e7d6..558af324fb 100644 --- a/torchtune/training/_compile.py +++ b/torchtune/training/_compile.py @@ -86,3 +86,9 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: else: loss = torch.compile(loss, backend=backend) return loss + +def compile_optimizer_step(optimizer_step_fn, verbose: bool = True): + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if verbose: + log.info("Compiling optimizer step function with torch.compile...") + return torch.compile(optimizer_step_fn, backend=backend) From 4023ed5ce2a77cd396e3a37f4a953b5bb431fab2 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 22 Apr 2025 06:51:16 -0700 Subject: [PATCH 2/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- recipes/full_finetune_distributed.py | 5 +++-- torchtune/training/__init__.py | 6 +++++- torchtune/training/_compile.py | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 2c94ae5f2d..527e80fded 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -925,11 +925,12 @@ def train(self) -> None: grad_norm = grad_norm.full_tensor() optimizer_step_fn = self._optimizer.step if self._compile: + def _fn(): self._optimizer.step() + 
optimizer_step_fn = training.compile_optimizer_step( - _fn, - verbose=self._is_rank_zero + _fn, verbose=self._is_rank_zero ) optimizer_step_fn() self._optimizer.zero_grad(set_to_none=True) diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index e19d4faf3b..cf9300e41b 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -8,7 +8,11 @@ NoOpManager, OffloadActivations, ) -from torchtune.training._compile import compile_loss, compile_model, compile_optimizer_step +from torchtune.training._compile import ( + compile_loss, + compile_model, + compile_optimizer_step, +) from torchtune.training._distributed import ( gather_cpu_state_dict, get_distributed_backend, diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py index 558af324fb..65ab7d7631 100644 --- a/torchtune/training/_compile.py +++ b/torchtune/training/_compile.py @@ -87,6 +87,7 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: loss = torch.compile(loss, backend=backend) return loss + def compile_optimizer_step(optimizer_step_fn, verbose: bool = True): backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") if verbose: From e12e17fe3cef53cd644cca443dd9d2634622e291 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 22 Apr 2025 09:40:18 -0700 Subject: [PATCH 3/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- recipes/full_finetune_distributed.py | 11 +++++------ torchtune/training/__init__.py | 7 +------ torchtune/training/_compile.py | 7 ------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 527e80fded..695aad12e7 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import os import sys import time @@ -306,6 +307,7 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self._checkpoint_client.load_base_checkpoint() self._compile = cfg.get("compile", False) + self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, @@ -925,12 +927,9 @@ def train(self) -> None: grad_norm = grad_norm.full_tensor() optimizer_step_fn = self._optimizer.step if self._compile: - - def _fn(): - self._optimizer.step() - - optimizer_step_fn = training.compile_optimizer_step( - _fn, verbose=self._is_rank_zero + optimizer_step_fn = torch.compile( + optimizer_step_fn, + backend=self._compile_backend, ) optimizer_step_fn() self._optimizer.zero_grad(set_to_none=True) diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index cf9300e41b..b2c327c617 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -8,11 +8,7 @@ NoOpManager, OffloadActivations, ) -from torchtune.training._compile import ( - compile_loss, - compile_model, - compile_optimizer_step, -) +from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( gather_cpu_state_dict, get_distributed_backend, @@ -139,7 +135,6 @@ "setup_torch_profiler", "compile_loss", "compile_model", - "compile_optimizer_step", "NoOpManager", "OffloadActivations", "FormattedCheckpointFiles", diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py index 65ab7d7631..ca02a3e7d6 100644 --- a/torchtune/training/_compile.py +++ b/torchtune/training/_compile.py @@ -86,10 +86,3 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: else: loss = torch.compile(loss, backend=backend) return loss - - -def compile_optimizer_step(optimizer_step_fn, verbose: bool = True): - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - if verbose: - log.info("Compiling optimizer step function with torch.compile...") - return torch.compile(optimizer_step_fn, backend=backend) From 568cdb4bbe7b6bfc517ccf81dd440124e69514c5 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 28 Apr 2025 04:02:22 -0700 Subject: [PATCH 4/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- recipes/configs/llama4/scout_17B_16E_full.yaml | 5 +++++ recipes/full_finetune_distributed.py | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/recipes/configs/llama4/scout_17B_16E_full.yaml b/recipes/configs/llama4/scout_17B_16E_full.yaml index ce7d765f7b..3f26dc6254 100644 --- a/recipes/configs/llama4/scout_17B_16E_full.yaml +++ b/recipes/configs/llama4/scout_17B_16E_full.yaml @@ -71,6 +71,11 @@ enable_activation_offloading: False fsdp_cpu_offload: True compile: False # torch.compile, set to true for perf/memory improvement +compile_components: + model: True + loss: True + optimizer_step: False + # Reduced precision dtype: bf16 diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 695aad12e7..0fd6ef4f37 100644 
--- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -308,6 +308,17 @@ def setup(self, cfg: DictConfig) -> None: self._compile = cfg.get("compile", False) self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + + self._compile_model = False + self._compile_loss = False + self._compile_optimizer_step = False + if self._compile: + self._compile_model = cfg.get("compile_components.model", True) + self._compile_loss = cfg.get("compile_components.loss", True) + self._compile_optimizer_step = cfg.get( + "compile_components.optimizer_step", False + ) + self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, @@ -359,7 +370,7 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._compile: + if self._compile_loss: training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": @@ -570,7 +581,7 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) - if self._compile: + if self._compile_model: training.compile_model(model, verbose=self._is_rank_zero) if self._enable_fp8_training: @@ -926,7 +937,7 @@ def train(self) -> None: if isinstance(grad_norm, DTensor): grad_norm = grad_norm.full_tensor() optimizer_step_fn = self._optimizer.step - if self._compile: + if self._compile_optimizer_step: optimizer_step_fn = torch.compile( optimizer_step_fn, backend=self._compile_backend, From 7a6a4d2a91db615bb98917ace0777c21a88aced3 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 28 Apr 2025 06:50:55 -0700 Subject: [PATCH 5/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- recipes/full_finetune_distributed.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 0fd6ef4f37..d3190d6d1b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -312,11 +312,12 @@ def setup(self, cfg: DictConfig) -> None: self._compile_model = False self._compile_loss = False self._compile_optimizer_step = False - if self._compile: - self._compile_model = cfg.get("compile_components.model", True) - self._compile_loss = cfg.get("compile_components.loss", True) - self._compile_optimizer_step = cfg.get( - "compile_components.optimizer_step", False + compile_components = cfg.get("compile_components") + if self._compile and compile_components: + self._compile_model = compile_components.get("model", True) + self._compile_loss = compile_components.get("loss", True) + self._compile_optimizer_step = compile_components.get( + "optimizer_step", False ) self._model = self._setup_model( @@ -341,6 +342,11 @@ def setup(self, cfg: DictConfig) -> None: else None ), ) + if self._compile_optimizer_step: + self._optimizer.step = torch.compile( + self._optimizer.step, + backend=self._compile_backend, + ) if self._resume_from_checkpoint: # If async 
checkpointing is enabled, intermediate checkpoints are saved asynchronously @@ -936,13 +942,7 @@ def train(self) -> None: # If sharded, collect the DTensor here if isinstance(grad_norm, DTensor): grad_norm = grad_norm.full_tensor() - optimizer_step_fn = self._optimizer.step - if self._compile_optimizer_step: - optimizer_step_fn = torch.compile( - optimizer_step_fn, - backend=self._compile_backend, - ) - optimizer_step_fn() + self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) # Update the number of steps when the weights are updated From f2f49bd18e9338d56c55b0c9bb6f7e51bf9f4f7a Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 28 Apr 2025 11:16:16 -0700 Subject: [PATCH 6/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- .../configs/llama4/scout_17B_16E_full.yaml | 14 +++++++------ recipes/full_finetune_distributed.py | 20 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/recipes/configs/llama4/scout_17B_16E_full.yaml b/recipes/configs/llama4/scout_17B_16E_full.yaml index 3f26dc6254..4a4715b959 100644 --- a/recipes/configs/llama4/scout_17B_16E_full.yaml +++ b/recipes/configs/llama4/scout_17B_16E_full.yaml @@ -69,12 +69,14 @@ device: cuda enable_activation_checkpointing: True enable_activation_offloading: False fsdp_cpu_offload: True -compile: False # torch.compile, set to true for perf/memory improvement - -compile_components: - model: True - loss: True - optimizer_step: False +# compile True means use torch.compile for all components +# compile False means no torch.compile +# compile Dictionary with keys: "model", "loss", "optimizer_step" +# enables torch.compile only for specified components. 
+compile: False +# model: True +# loss: True +# optimizer_step: False # Reduced precision dtype: bf16 diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index d3190d6d1b..ab8a8c7905 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -306,19 +306,17 @@ def setup(self, cfg: DictConfig) -> None: # Load the base model checkpoint_dict = self._checkpoint_client.load_base_checkpoint() - self._compile = cfg.get("compile", False) + compile = cfg.get("compile") + compile_bool = bool(compile) self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - self._compile_model = False - self._compile_loss = False - self._compile_optimizer_step = False - compile_components = cfg.get("compile_components") - if self._compile and compile_components: - self._compile_model = compile_components.get("model", True) - self._compile_loss = compile_components.get("loss", True) - self._compile_optimizer_step = compile_components.get( - "optimizer_step", False - ) + self._compile_model = compile_bool + self._compile_loss = compile_bool + self._compile_optimizer_step = compile_bool + if isinstance(compile, dict): + self._compile_model = compile.get("model", True) + self._compile_loss = compile.get("loss", True) + self._compile_optimizer_step = compile.get("optimizer_step", False) self._model = self._setup_model( cfg_model=cfg.model, From d8b94316d339a59f29cc01445d185bb93db87d7f Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 28 Apr 2025 11:59:31 -0700 Subject: [PATCH 7/7] Update on "compile optimizer" Compiling optimizer helps perf of Llama4 Scout Model 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations) peak memory is the same ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Current repo compilation fails if to set `skip_rope_interval=4,`, have to test with `skip_rope_interval=None,` [ghstack-poisoned] --- recipes/full_finetune_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index ab8a8c7905..cfacbbd148 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -313,7 +313,7 @@ def setup(self, cfg: DictConfig) -> None: self._compile_model = compile_bool self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool - if isinstance(compile, dict): + if isinstance(compile, DictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False)
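
Taken together, the series settles on a `compile` option that accepts either a boolean or a per-component mapping with `model`, `loss`, and `optimizer_step` keys. The sketch below mirrors the resolution logic from the final patch's `setup()`, assuming the option arrives as an OmegaConf `DictConfig` just as in the diff; the standalone `resolve_compile_flags` helper and the example configs are illustrative only, not part of the change.

```python
from omegaconf import DictConfig, OmegaConf


def resolve_compile_flags(compile_cfg) -> tuple[bool, bool, bool]:
    # Mirrors setup() in the final patch: a plain truthy/falsy value toggles
    # everything, while a DictConfig selects components individually.
    compile_bool = bool(compile_cfg)
    compile_model = compile_bool
    compile_loss = compile_bool
    compile_optimizer_step = compile_bool
    if isinstance(compile_cfg, DictConfig):
        compile_model = compile_cfg.get("model", True)
        compile_loss = compile_cfg.get("loss", True)
        compile_optimizer_step = compile_cfg.get("optimizer_step", False)
    return compile_model, compile_loss, compile_optimizer_step


# Boolean form: everything on or off.
cfg = OmegaConf.create({"compile": True})
print(resolve_compile_flags(cfg.get("compile")))  # (True, True, True)

# Mapping form: compile only the optimizer step.
cfg = OmegaConf.create(
    {"compile": {"model": False, "loss": False, "optimizer_step": True}}
)
print(resolve_compile_flags(cfg.get("compile")))  # (False, False, True)
```

Note the asymmetric defaults in the mapping form: `model` and `loss` default to True while `optimizer_step` defaults to False, so a sparse mapping behaves like the old boolean `compile: True` with the optimizer step opted out.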
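
The optimizer-step compilation itself ends up as a one-time wrap of the bound `step` method during setup (patch 5), rather than re-wrapping it on every iteration inside `train()` as the first version did, so the compiled artifact is reused across steps. Below is a minimal standalone sketch of that pattern, assuming a recent PyTorch with compiled-optimizer support; the toy `Linear` model and AdamW instance are placeholders, not the recipe's objects.

```python
import os

import torch

# Backend choice matches the recipe: TORCH_COMPILE_BACKEND overrides "inductor".
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Wrap the bound step method once, up front; subsequent optimizer.step() calls
# reuse the compiled version instead of tracing again.
optimizer.step = torch.compile(optimizer.step, backend=backend)

# One illustrative iteration; compilation is triggered lazily on the first call.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```

Wrapping at setup time also keeps the training loop unchanged: the final diff calls `self._optimizer.step()` directly in `train()`, and the compiled wrapper is transparent at the call site.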