
Conversation

@IvanKobzarev
Contributor

@IvanKobzarev IvanKobzarev commented Apr 22, 2025

Stack from ghstack (oldest at bottom):

Compiling the optimizer helps perf of the Llama4 Scout model:
3.8 tokens_per_second -> 9 tokens_per_second (max tokens per second over the first ~10 iterations)
Peak memory is the same.

```
tune run --nproc_per_node 8 \
  full_finetune_distributed \
  --config recipes/configs/llama4/scout_17B_16E_full.yaml
```

PS:
Compilation in the current repo fails if `skip_rope_interval=4,` is set; have to test with `skip_rope_interval=None,`.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2623

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit af77178 with merge base 4bc5af2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 22, 2025
Member

@joecummings joecummings left a comment


Since we're now compiling several things independently, it might make sense logically to have a section of the recipe where we compile everything after instantiation.

return loss


def compile_optimizer_step(optimizer_step_fn, verbose: bool = True):
Member

I appreciate you wanting to keep this similar to how we're currently doing things; however, we only needed to do this for the loss function because we were doing funky things with chunking.

We should just compile this directly in the recipe. Same goes for the other PR you have up.
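For illustration, a minimal standalone sketch of compiling a bound optimizer step once during setup and reusing it every iteration (toy names like `compile_backend` stand in for the recipe's attributes; this is not the exact change in this PR):

```
import torch

# Toy sketch: compile optimizer.step once, then reuse the compiled callable.
params = [torch.nn.Parameter(torch.randn(4))]
optimizer = torch.optim.AdamW(params, lr=1e-3)

compile_backend = "inductor"  # stand-in for self._compile_backend
optimizer_step_fn = torch.compile(optimizer.step, backend=compile_backend)

loss = (params[0] ** 2).sum()
loss.backward()
optimizer_step_fn()  # updates the same optimizer state as optimizer.step()
optimizer.zero_grad()
```

The training loop then calls the compiled `optimizer_step_fn()` where it previously called `self._optimizer.step()`.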

if isinstance(grad_norm, DTensor):
    grad_norm = grad_norm.full_tensor()
self._optimizer.step()
optimizer_step_fn = self._optimizer.step
Member

See comment below, we can just compile the optimizer step in the recipe directly.

Contributor Author

Yeah, agree, just copied the previous setup. Will move compile to the recipe.

Contributor

Noob q: is there a reason we need to compile self._optimizer.step every step? Why is it different than the model, which we compile one time upfront?

Contributor

When I tried this out, I found issues with setting up the LR scheduler, which fails when attempting to wrap the optimizer step fn.
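For what it's worth, a hedged toy sketch of one ordering that sidesteps that clash: build the LR scheduler against the untouched optimizer first, then compile the bound step into a separate callable instead of overwriting `optimizer.step` (names here are illustrative, not this PR's code):

```
import torch

# Build the scheduler before any compilation, against the plain optimizer.
params = [torch.nn.Parameter(torch.zeros(8))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

# Compile into a separate callable; optimizer.step itself is left untouched,
# so the scheduler's bookkeeping around it keeps working.
optimizer_step_fn = torch.compile(optimizer.step)

(params[0] ** 2).sum().backward()
optimizer_step_fn()
scheduler.step()
optimizer.zero_grad()
```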

Contributor Author

Thanks, will check the compile optimizer error with LR scheduler.

@joecummings
Member

Ah sorry, did not mean to approve :)

@codecov-commenter

codecov-commenter commented Apr 22, 2025

Codecov Report

Attention: Patch coverage is 0% with 15 lines in your changes missing coverage. Please review.

Please upload report for BASE (gh/IvanKobzarev/1/base@eed2665). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| recipes/full_finetune_distributed.py | 0.00% | 15 Missing ⚠️ |
Additional details and impacted files
```
@@                    Coverage Diff                    @@
##             gh/IvanKobzarev/1/base    #2623   +/-   ##
=========================================================
  Coverage                          ?   63.97%           
=========================================================
  Files                             ?      399           
  Lines                             ?    24241           
  Branches                          ?        0           
=========================================================
  Hits                              ?    15507           
  Misses                            ?     8734           
  Partials                          ?        0           

☔ View full report in Codecov by Sentry.

Comment on lines 929 to 933
if self._compile:
    optimizer_step_fn = torch.compile(
        optimizer_step_fn,
        backend=self._compile_backend,
    )
Contributor

Some optimizers might not work with this, if I remember correctly, e.g. torchao/bnb. May need some testing. The safest option might be to add a compile flag per area, e.g.:

```
compile:
  loss: True
  model: True
  optimizer_step: False
```
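For illustration, a rough Python-side sketch of how a recipe could honor such per-area flags (the flag names mirror the suggestion above; everything else is a hypothetical stand-in, not this repo's code):

```
import torch

# Hypothetical per-area compile flags, e.g. parsed from a config like the one above.
compile_flags = {"model": True, "loss": True, "optimizer_step": False}

model = torch.nn.Linear(16, 16)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Compile only the pieces that are flagged on; optimizers that don't play
# well with torch.compile (e.g. torchao/bnb) can simply keep their flag off.
if compile_flags["model"]:
    model = torch.compile(model)
if compile_flags["loss"]:
    loss_fn = torch.compile(loss_fn)
optimizer_step_fn = (
    torch.compile(optimizer.step)
    if compile_flags["optimizer_step"]
    else optimizer.step
)
```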

@IvanKobzarev
Contributor Author

IvanKobzarev commented Apr 28, 2025

Changed to direct compilation of `self._optimizer.step` and it works :)
Updated the diff.

Just FYI for testing: compilation at the moment needs workarounds for 2 different problems:

  1. There is a problem with RNG state preservation, which can be worked around with the following patch:

```
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 668353867ab..493883542f9 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -249,8 +249,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             prior_dtype = torch.get_default_dtype()
             torch_rng_state = torch.random.get_rng_state()
             cuda_rng_state = None
-            if torch.cuda.is_available():
-                cuda_rng_state = torch.cuda.get_rng_state()
+            # if torch.cuda.is_available():
+            #     cuda_rng_state = torch.cuda.get_rng_state()
             allow_tf32 = torch._C._get_cublas_allow_tf32()
             prior_fwd_from_src = torch.fx.graph_module._forward_from_src
             torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
@@ -281,8 +281,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
                 )
                 if prior_mobile_allocator_state != curr_mobile_allocator_state:
                     torch._C._unset_default_mobile_cpu_allocator()
-                if cuda_rng_state is not None:
-                    torch.cuda.set_rng_state(cuda_rng_state)
+                # if cuda_rng_state is not None:
+                #     torch.cuda.set_rng_state(cuda_rng_state)
                 torch._C._set_cublas_allow_tf32(allow_tf32)
                 torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                 assert guards.check(), (
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b75b1d6c39f..7ca67523704 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2110,15 +2110,15 @@ def preserve_rng_state():
     with disable_current_modes(), disable_functorch():
         rng_state = torch.clone(torch.random.get_rng_state())
         skip_frame_if_in_functorch_mode(rng_state)
-        if torch.cuda.is_available():
-            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+        # if torch.cuda.is_available():
+        #     cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
     try:
         yield
     finally:
         with torch.utils._python_dispatch._disable_current_modes():
             torch.random.set_rng_state(rng_state)
-            if torch.cuda.is_available():
-                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+            # if torch.cuda.is_available():
+            #     torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
 
 
 def is_jit_model(model0):
```

  2. There is an illegal memory access in chunked flex attention x caching.

If `/tmp/torchinductor_${USER}` is removed before every run, then it does not fire (alternatively, disable the PT2 cache).

Member

@joecummings joecummings left a comment


Just one nit on naming, but this looks good!

fsdp_cpu_offload: True
compile: False # torch.compile, set to true for perf/memory improvement

compile_components:
Member

nit: could we match the argument to just "compile"? Then valid arguments would be "True", "False", or the specific components. If "True", then we compile everything. If "False", we compile nothing. If the argument has a dictionary with each component, then we follow those instructions.
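A hedged sketch of that bool-or-mapping behavior (the component names and helper are illustrative, not the final interface):

```
from typing import Dict, Union

_COMPONENTS = ("model", "loss", "optimizer_step")

def resolve_compile(compile_cfg: Union[bool, Dict[str, bool]]) -> Dict[str, bool]:
    """Normalize the `compile` argument into one flag per component."""
    if isinstance(compile_cfg, bool):
        # True -> compile everything, False -> compile nothing.
        return dict.fromkeys(_COMPONENTS, compile_cfg)
    # A mapping selects components individually; unlisted ones stay eager.
    return {name: bool(compile_cfg.get(name, False)) for name in _COMPONENTS}

assert resolve_compile(True) == {"model": True, "loss": True, "optimizer_step": True}
assert resolve_compile({"model": True}) == {
    "model": True,
    "loss": False,
    "optimizer_step": False,
}
```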

Contributor Author

Ok. Agree with this logic, will update to it.

@IvanKobzarev IvanKobzarev merged commit 28dbc97 into gh/IvanKobzarev/1/base May 2, 2025
14 checks passed
IvanKobzarev added a commit that referenced this pull request May 2, 2025

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

8 participants