Add validation dataset loss to distributed SFT recipes #2464

Open
wants to merge 7 commits into base: main

Conversation

@bzz commented Mar 6, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Addresses #1042 / part of #883 for distributed recipes

Changelog

What are the changes made in this PR?

New configuration:

dataset_validation: null
run_val_every_n_steps: null
max_validation_batches: -1
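
For context, here is a rough sketch of how these options could drive a validation pass; run_validation and its signature are illustrative assumptions, not the actual recipe code.

import torch
from torch import nn
from torch.utils.data import DataLoader

def run_validation(
    model: nn.Module,
    val_dataloader: DataLoader,
    loss_step,                      # callable(batch) -> scalar loss tensor
    max_validation_batches: int = -1,
) -> float:
    # Forward-only pass: average the loss over at most max_validation_batches
    # batches (-1 means the whole validation set).
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for idx, batch in enumerate(val_dataloader):
            if 0 < max_validation_batches <= idx:
                break
            total_loss += loss_step(batch).item()
            n_batches += 1
    model.train()
    # In the distributed recipes, the per-rank loss would additionally be
    # all-reduced across ranks before logging.
    return total_loss / max(n_batches, 1)

dataset_validation selects the validation dataset, the loop above would run every run_val_every_n_steps steps, and leaving run_val_every_n_steps as null keeps validation disabled.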

Test plan

As suggested, I ran with compile + optimizer_in_bwd + activation checkpointing + activation offloading on 2xH100, using mhenrichsen/alpaca_2k_test as the validation dataset.

LoRA Distributed

pip install -U git+https://github.com/bzz/torchtune.git@validate_distributed

# with validation
tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config 1B_qlora_validation.yaml \
	compile=True \
	enable_activation_checkpointing=True \
	enable_activation_offloading=True \
	optimizer_in_bwd=True \
	tokenizer.max_seq_len=4096 \
	gradient_accumulation_steps=1 \
	epochs=1 \
	batch_size=16 \
	metric_logger._component_=torchtune.training.metric_logging.WandBLogger \
	output_dir=/tmp/torchtune/llama3_2_1B/qlora_val

Without validation

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config 1B_qlora_validation.yaml \
	run_val_every_n_steps=None \
	compile=True \
	enable_activation_checkpointing=True \
	enable_activation_offloading=True \
	optimizer_in_bwd=True \
	tokenizer.max_seq_len=4096 \
	gradient_accumulation_steps=1 \
	epochs=1 \
	batch_size=16 \
	metric_logger._component_=torchtune.training.metric_logging.WandBLogger \
	output_dir=/tmp/torchtune/llama3_2_1B/qlora
[Screenshot 2025-03-06 at 08 05 19]
llama3_2/1B_qlora_validation.yaml
diff recipes/configs/llama3_2/1B_qlora_validation.yaml recipes/configs/llama3_2/1B_lora.yaml

20c20
< output_dir: /tmp/torchtune/llama3_2_1B/qlora # /tmp may be deleted by your system. Change it to your preference.
---
> output_dir: /tmp/torchtune/llama3_2_1B/lora # /tmp may be deleted by your system. Change it to your preference.
30c30
<   _component_: torchtune.models.llama3_2.qlora_llama3_2_1b
---
>   _component_: torchtune.models.llama3_2.lora_llama3_2_1b
57,65d56
< # Validation
< dataset_validation:
<   _component_: torchtune.datasets.alpaca_dataset
<   source: mhenrichsen/alpaca_2k_test
<   split: train
<   packed: False  # True increases speed
< run_val_every_n_steps: 100
< max_validation_batches: -1
<
82c73
< gradient_accumulation_steps: 1  # Use to increase effective batch size
---
> gradient_accumulation_steps: 8  # Use to increase effective batch size
84c75
< compile: True  # torch.compile the model + loss, True increases speed + decreases memory
---
> compile: False  # torch.compile the model + loss, True increases speed + decreases memory
88c79
<   _component_: torchtune.training.metric_logging.WandBLogger
---
>   _component_: torchtune.training.metric_logging.DiskLogger

The difference in throughput comes from both runs sharing the same 2 GPUs.

Full Distributed

tune cp llama3_2/1B_full 1B_full_validation

with the following added to the copied config:

# Validation
dataset_validation:
  _component_: torchtune.datasets.alpaca_dataset
  source: mhenrichsen/alpaca_2k_test
  split: train
run_val_every_n_steps: 100
max_validation_batches: -1

metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger

With validation

tune run --nnodes 1 --nproc_per_node 2 \
    full_finetune_distributed --config 1B_full_validation.yaml \
	compile=True \
	enable_activation_checkpointing=True \
	enable_activation_offloading=True \
	optimizer_in_bwd=True \
	tokenizer.max_seq_len=4096 \
	gradient_accumulation_steps=1 \
	epochs=1 \
	batch_size=16 \
	output_dir=/tmp/torchtune/llama3_2_1B/full_dist_val

And without

tune run --nnodes 1 --nproc_per_node 2 \
    full_finetune_distributed --config 1B_full_validation.yaml \
	run_val_every_n_steps=None \
	compile=True \
	enable_activation_checkpointing=True \
	enable_activation_offloading=True \
	optimizer_in_bwd=True \
	tokenizer.max_seq_len=4096 \
	gradient_accumulation_steps=1 \
	epochs=1 \
	batch_size=16 \
	output_dir=/tmp/torchtune/llama3_2_1B/full_dist
[Screenshot 2025-03-06 at 11 34 48]

Checklist:
  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

The output is quite simplistic and looks like this:

1|100|Loss: 3.8272886276245117:   6%|████                                                               | 100/1625 [01:05<05:46,  4.41it/s]
INFO:torchtune.utils._logging:Validation loss: 4.8242
1|200|Loss: 1.7956730127334595:  12%|████████▏                                                          | 200/1625 [01:30<05:11,  4.58it/s]
INFO:torchtune.utils._logging:Validation loss: 1.8639
1|300|Loss: 1.4479581117630005:  18%|████████████▎                                                      | 300/1625 [01:52<03:54,  5.66it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7377
1|400|Loss: 1.433119535446167:  25%|████████████████▋                                                   | 400/1625 [02:13<03:51,  5.28it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7247
1|500|Loss: 1.5169874429702759:  31%|████████████████████▌                                              | 500/1625 [02:34<03:13,  5.81it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7714
1|600|Loss: 1.4446262121200562:  37%|████████████████████████▋                                          | 600/1625 [02:55<02:54,  5.89it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7542
1|700|Loss: 1.5118173360824585:  43%|████████████████████████████▊                                      | 700/1625 [03:18<02:54,  5.32it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7290
1|800|Loss: 1.2981687784194946:  49%|████████████████████████████████▉                                  | 800/1625 [03:40<02:09,  6.35it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7309
1|900|Loss: 1.533372163772583:  55%|█████████████████████████████████████▋                              | 900/1625 [04:01<02:22,  5.08it/s]
INFO:torchtune.utils._logging:Validation loss: 1.7264

Please let me know if you think it needs to be improved somehow.

  • I did not change any public API
  • I have added an example to docs or docstrings


@pytorch-bot (bot) commented Mar 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2464


@facebook-github-bot added the CLA Signed label Mar 6, 2025
@felipemello1 (Contributor) commented:

nice! Let me review it this monday

@felipemello1 self-assigned this Mar 9, 2025
@felipemello1 (Contributor) left a comment

this looks very good! Just left a few comments, let me know what you think

packed: False # True increases speed
seed: null
shuffle: True
batch_size: 4
@felipemello1 (Contributor):

We probably want a larger batch_size for validation, since we don't need the extra memory for activations + grads. Maybe batch_size_validation?

Not for this PR: We need to introduce a dataloader config, to manage bsz, shuffle, etc. I am working on it in a different PR.

@bzz (Author) replied Mar 18, 2025:

That makes sense, but it would impede re-use of a single _loss_step() as-is, as it relies on self.ignore_labels_cache being sized and set up with cfg.batch_size, which would then need to differ between training and validation 🤔

If you think it is worth adding here, I'd appreciate advice on the best way to go about it: passing a flag into _loss_step() (e.g. train=True), or adding it as a state/property? Having two caches, or checking shapes and skipping the cache and slicing logits?
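
For what it's worth, a minimal sketch of the "check shapes & skip cache" option, assuming ignore_labels_cache stays pre-allocated with the training batch_size; the helper name shift_labels and its placement are hypothetical.

import torch

def shift_labels(
    labels: torch.Tensor,
    ignore_labels_cache: torch.Tensor,
    ignore_index: int,
) -> torch.Tensor:
    # Shift labels left by one and pad the last position with ignore_index.
    # Use the pre-allocated cache when the batch fits; otherwise (e.g. a larger
    # validation batch_size) build the ignore column on the fly.
    bsz = labels.shape[0]
    if bsz <= ignore_labels_cache.shape[0]:
        ignore_col = ignore_labels_cache[:bsz]
    else:
        ignore_col = torch.full(
            (bsz, 1), ignore_index, device=labels.device, dtype=labels.dtype
        )
    return torch.hstack((labels[..., 1:], ignore_col))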


labels = batch.pop("labels")

with self.activations_handling_ctx:
@felipemello1 (Contributor):

We don't need to offload activations because there is no backpropagation. We would need to double-check whether this is a no-op or whether we are wasting resources by having it here. Either way, I think we can delete it.

@bzz (Author) replied Mar 18, 2025:

Thanks for pointing that out! Oh, so lora_distributed does not have a _loss_step() while full now does, so I guess deleting is not an option then.

I'll double-check its implications (whether this is a no-op), but is there value in adding _loss_step() to the lora distributed recipe and re-using the same approach for both in this PR?
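
One way a shared _loss_step() could skip the offloading cost during validation is to choose the forward context based on whether a backward pass will follow; this is only a sketch, and forward_context is a made-up helper.

import torch

def forward_context(activations_handling_ctx, training: bool):
    # Activation offloading only pays off when a backward pass follows, so a
    # forward-only validation step can fall back to a plain no-grad context.
    if training:
        return activations_handling_ctx
    return torch.no_grad()

# Usage sketch inside a shared _loss_step(batch, training=True):
#     with forward_context(self.activations_handling_ctx, training):
#         outputs = self._model(**batch)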

else float("inf")
)

self._model.train()
@felipemello1 (Contributor):

This is OK here, but I wonder if it should be in the training loop.

@bzz (Author) replied Mar 18, 2025:

I added it here on the premise that it's easier to reason about the model state "by inspection" when one can see it being changed before/after validation, in a single place.

Would c7095d8 do?
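
For clarity, the two placements being weighed, sketched roughly; _compute_validation_loss and _should_validate are hypothetical helpers.

# Option A: validate() owns the mode switch, so the eval -> train transition
# is visible "by inspection" in a single place.
def validate(self) -> float:
    self._model.eval()
    val_loss = self._compute_validation_loss()  # hypothetical helper
    self._model.train()
    return val_loss

# Option B: the training loop owns the switch and validate() stays forward-only.
#     if self._should_validate(step):           # hypothetical helper
#         self._model.eval()
#         val_loss = self.validate()
#         self._model.train()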

bzz added 5 commits March 17, 2025 22:45
Introduces new configuration options

dataset_validation:
run_val_every_n_steps: 100
max_validation_batches: -1

Test plan:
 - tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2/1B_lora_validation.yaml
Introduces the same configuration options as in the LoRA finetune distributed recipe

dataset_validation:
run_val_every_n_steps: 100
max_validation_batches: -1

Test plan:
 - tune run --nproc_per_node 2 full_finetune_distributed --config ...
@bzz force-pushed the validate_distributed branch from f516db7 to bac2247 on March 17, 2025 21:46
@bzz force-pushed the validate_distributed branch from 96bb4f4 to 0f2861d on March 18, 2025 17:36
@bzz force-pushed the validate_distributed branch from 0f2861d to c7095d8 on March 18, 2025 17:43