Add validation dataset loss to distributed SFT recipes #2464
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2464
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Nice! Let me review it this Monday.
This looks very good! I just left a few comments, let me know what you think.
packed: False  # True increases speed
seed: null
shuffle: True
batch_size: 4
We probably want a larger batch_size for validation, since we don't need the extra memory for activations + grads. Maybe batch_size_validation?
Not for this PR: we need to introduce a dataloader config to manage batch size, shuffle, etc. I am working on it in a different PR.
That makes sense, but it would impede re-use of a single _loss_step() as-is, since it relies on self.ignore_labels_cache, which is sized and set up with cfg.batch_size and would then need to differ between training and validation 🤔
If you think it is worth adding here, I'd appreciate advice on the best way to go about it: pass a flag to _loss_step() (e.g. train=True), or add it as state/a property? Keep two caches, or check shapes, skip the cache, and slice the logits?
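To make those options concrete, here is a minimal, self-contained sketch of the "single cache sized for the larger batch size, sliced per batch" variant. The batch_size_validation name and the label-shifting layout are assumptions taken from this thread, not the recipe's actual code:

```python
import torch

def shift_labels(labels: torch.Tensor, ignore_labels_cache: torch.Tensor) -> torch.Tensor:
    # Shift labels left by one token and pad the last position from the
    # pre-allocated ignore-index cache, sliced to the current batch size.
    return torch.hstack((labels[..., 1:], ignore_labels_cache[: labels.shape[0]]))

# Size the cache once for the larger of the two (hypothetical) batch sizes.
train_bsz, val_bsz, ignore_index = 4, 8, -100
ignore_labels_cache = torch.full((max(train_bsz, val_bsz), 1), ignore_index)

# The same helper then serves both the training and the validation loop.
train_labels = torch.randint(0, 10, (train_bsz, 16))
val_labels = torch.randint(0, 10, (val_bsz, 16))
assert shift_labels(train_labels, ignore_labels_cache).shape == (train_bsz, 16)
assert shift_labels(val_labels, ignore_labels_cache).shape == (val_bsz, 16)
```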
labels = batch.pop("labels")

with self.activations_handling_ctx:
We don't need to offload activations because there is no backpropagation. We would need to double-check whether this is a no-op or whether we are wasting resources by having it here. Either way, I think we can delete it.
Thanks for pointing that out! Oh, so lora_distributed does not have a _loss_step() while full now does, so I guess deleting is not an option then.
I'll double-check its implications (whether this is a no-op), but is there value in adding _loss_step() to the LoRA distributed recipe and re-using the same approach for both in this PR?
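For reference, a rough sketch of what a validation forward could look like without the offloading context; it assumes the loop computes the loss the same way _loss_step() does today, and the function name is illustrative:

```python
import contextlib
import torch
import torch.nn as nn

def validation_forward(model: nn.Module, batch: dict) -> torch.Tensor:
    # No backward pass happens during validation, so activation offloading /
    # checkpointing buys nothing: under torch.no_grad() activations are not
    # kept for backprop in the first place.
    ctx = contextlib.nullcontext()  # stands in for self.activations_handling_ctx
    with torch.no_grad(), ctx:
        return model(**batch)
```

If the training _loss_step() ends up being reused for both recipes instead, the same idea could be expressed by swapping its offloading context for contextlib.nullcontext() when not training.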
recipes/lora_finetune_distributed.py (Outdated)
else float("inf") | ||
) | ||
|
||
self._model.train() |
This is OK here, but I wonder if it should be in the training loop.
I added it here on the premise that it's easier to reason about the model state "by inspection" when one can see it being changed before/after validation, in a single place.
Would c7095d8 do?
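Concretely, the shape I mean is roughly the sketch below (illustrative only, not the exact code in c7095d8; it assumes the recipe has _val_dataloader and _loss_step attributes, and it skips the cross-rank loss reduction a distributed recipe would also need):

```python
import torch

def validate(self, max_batches: int = -1) -> float:
    # The eval/train switch lives in one place, wrapped around the whole loop.
    self._model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for idx, batch in enumerate(self._val_dataloader):
            if max_batches >= 0 and idx >= max_batches:
                break
            total_loss += self._loss_step(batch).item()
            n_batches += 1
    self._model.train()  # hand the model back to the training loop
    return total_loss / max(n_batches, 1)
```

The training loop would then just gate the call on run_val_every_n_steps and pass max_validation_batches through.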
Introduces new configuration options:
dataset_validation:
run_val_every_n_steps: 100
max_validation_batches: -1
Test plan:
- tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2/1B_lora_validation.yaml
Introduces the same configuration options as in the LoRA finetune distributed recipe:
dataset_validation:
run_val_every_n_steps: 100
max_validation_batches: -1
Test plan:
- tune run --nproc_per_node 2 full_finetune_distributed --config ...
Context
What is the purpose of this PR?
Addresses #1042 / part of #883 for distributed recipes.
Changelog
What are the changes made in this PR?
New configuration:
dataset_validation:
run_val_every_n_steps: 100
max_validation_batches: -1

Test plan
As suggested, I ran with compile + opt_in_bwd + activation ckpt + activation offloading on 2xH100, using mhenrichsen/alpaca_2k_test as the validation dataset.
LoRA Distributed
Without validation
The difference in throughput is due to both runs using the same 2 GPUs.
Full Distributed
Validation
And without validation:
pre-commit install
pytest tests
pytest tests -m integration_test
UX
The output is really simplistic and looks like:
Please let me know if you think it needs to be improved somehow.