enable fine tuning on HPU #552
Conversation
Thanks for the PR! The block alignment for Gaudi should happen outside of the Multipack logic. Please update your PR so that samples are correctly aligned for Gaudi systems outside of the sampler logic.
@@ -211,11 +213,11 @@ def ffd_check_padding(a: np.ndarray, c: int, n: int):
        not_found = True
        for idx in range(n):
            # Calculate the new capacity if size is added to the bin
-           new_capacity = max(bins_max_lengths[idx], size) * (
+           new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
Why are we using `bucket` in multipack? This affects sampling for NVIDIA and AMD systems as well.
Multipack decides how many samples, and of what size, we can fit in any given batch. Bucketing increases the maximum sample length of the batch and thus changes the batch layout, so moving bucketing out of multipack decreases bucketing efficiency. Regarding the impact on other platforms: `bucket()` is defined as an identity function on anything except HPU, so it should have no impact.
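For context, a minimal sketch of the contract being described here, assuming import-based HPU detection and a fixed bucket size of 128; the PR's actual `bucket()` and `is_torch_hpu_available()` may be implemented differently:

```python
# Minimal sketch, not the PR's code: bucket() behaves as the identity
# everywhere except HPU, where lengths are rounded up to a coarse
# boundary so Gaudi compiles graphs for fewer distinct shapes.
# The detection logic and the bucket size of 128 are assumptions.
import math


def is_torch_hpu_available() -> bool:
    """Assume HPU is present iff the Habana PyTorch plugin can be imported."""
    try:
        import habana_frameworks.torch  # noqa: F401
        return True
    except ImportError:
        return False


def bucket(length: int) -> int:
    """Round a sample length up to the next bucket boundary on HPU;
    act as the identity function on every other platform."""
    if not is_torch_hpu_available() or length <= 0:
        return length
    return math.ceil(length / 128) * 128
```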
@splotnikv Why do you think that bucketing would be ineffective when moved outside of multipack? From multipack's perspective, it's just looking at lengths. If you move the bucketing outside, you should still get the same values, I believe.
Functionally, I'm seeing that `bucket` only returns a non-identity value if Gaudi is available, so I think @splotnikv is correct that this doesn't impact CUDA.
That being said, re: this implementation, the integration would be preferred as:

    new_capacity = max(bins_max_lengths[idx], size) ...
    if is_torch_hpu_available():
        new_capacity = bucket(new_capacity)

rather than obfuscating the CUDA path inside of the `bucket` function.
@JamesKunstle Regardless of `bucket`'s behavior on other platforms, the Multipack implementation is intended as a generic algorithm which should only operate on what it believes are numbers.
It would be equivalent and preferred to simply run `bucket` on all of the lengths ahead of time and pass them into Multipack, rather than updating the algorithm to now also consider what platform it's running on.
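As a sketch of that suggestion, the lengths could be bucketed once before the sampler ever sees them; `bucket()` and `is_torch_hpu_available()` are assumed to be importable from this PR, and the sampler call shown in the comment is illustrative only:

```python
# Hedged sketch of "bucket the lengths ahead of time": apply bucket()
# before constructing the sampler so the Multipack algorithm itself
# only ever deals with plain numbers.
import numpy as np


def prepare_lengths(raw_lengths) -> np.ndarray:
    """Bucket sample lengths once, outside of the sampler logic."""
    lengths = np.asarray(raw_lengths, dtype=np.int64)
    if is_torch_hpu_available():
        lengths = np.vectorize(bucket)(lengths)
    return lengths


# The pre-bucketed lengths would then be handed to the sampler, e.g.:
# sampler = MultipackDistributedBatchSampler(
#     batch_max_length=packing_max_batch_len,
#     lengths=prepare_lengths(raw_lengths),
#     num_replicas=world_size,
#     rank=rank,
# )
```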
@@ -266,11 +268,11 @@ def ffd_with_result_padding(a: np.ndarray, c: int, start_index: int):
        add_new = True
        for idx in range(len(bins_max_lengths)):
            # Calculate the new capacity if size is added to the bin
-           new_capacity = max(bins_max_lengths[idx], size) * (
+           new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
This shouldn't go in the logic for multipack
Please see my reply above.
I've removed bucketing from the multipack sampler for now, but we still have to find a place to implement it. I've left it in the collate function, but that is not enough. Please advise how we can do it; e.g., I could implement an HPU-specific multipack sampler.
This pull request has merge conflicts that must be resolved before it can be merged.
In the next week or so we're going to implement some significantly better testing for this lib that'll give us more confidence in merging this. My ideal requirements to OK merging this are:
In general, this seems like a pretty minimally invasive integration, so I don't have much hesitation about merging it. I'd just like to make sure we've got good performance / correctness benchmarks ahead of time to hedge against regressions.
@RobotSail, I've refactored the bucketing implementation and moved it almost completely out of the multipack sampler; one conditional call still remains. @JamesKunstle, I've done minimal testing on HPU on our side and everything seems to work as expected; more testing is in progress and I'll post the results here when it's finished. I am also working on documentation and will add it in the next commit.
if is_torch_hpu_available():
    bucket_v = np.vectorize(bucket)
    lengths = bucket_v(lengths)
@splotnikv This implementation makes more sense, but it seems like we are still mixing responsibilities with what is expected from the Multipack sampler here.
How can we provide the bucketed lengths before they even enter the sampler in the first place?
Hi @RobotSail, unfortunately this is the best we can do without refactoring multipack. The problem is in how multipack handles sample lengths: it takes them from two different places, (1) from the dataset (see `get_effective_samples_per_minibatch()`) and (2) from the arguments of the `MultipackDistributedBatchSampler` constructor. If I remove bucketing from `generate_batches`, I'll have to move it to `get_effective_samples_per_minibatch` and also add one more call in `setup_dataloader`, i.e. call bucketing in every place where we create a `MultipackDistributedBatchSampler`. Let me know if you prefer this solution and I'll update the PR.
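A sketch of the refactor being weighed here, assuming a single shared helper that both `get_effective_samples_per_minibatch()` and `setup_dataloader()` could call before constructing the sampler (their real signatures are not shown in this thread):

```python
# Hedged sketch: one helper run over the lengths at every call site that
# creates a MultipackDistributedBatchSampler, assuming bucket() and
# is_torch_hpu_available() are importable from this PR's code.
import numpy as np


def maybe_bucket_lengths(lengths) -> np.ndarray:
    """Identity on CUDA/ROCm; element-wise bucket() on HPU."""
    lengths = np.asarray(lengths, dtype=np.int64)
    if is_torch_hpu_available():
        lengths = np.vectorize(bucket)(lengths)
    return lengths
```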
@splotnikv I believe you will need to move this bucketing logic out of multipack entirely and into data processing anyway. Since our data processing script drops samples above a certain length, we assume the data is below this threshold by the time we run the training loop (everything past the `main` function call in `main_ds.py`).
Since `bucket` clips the lengths up to a certain value so that batches can be properly padded on HPUs, there's a chance that your samples will actually exceed this threshold.
@RobotSail, I removed bucketing completely from `MultipackDistributedBatchSampler`, but it is still present in `get_effective_samples_per_minibatch()`. Please have a look.
About sample-length handling: sorry, but I don't understand your concern. Bucketing does not change a sample's length; it remains exactly the same. All batch-creation functionality remains the same and should be based on the actual sample length. Bucketing should be done at the very end, when the batch has been created and is ready to be sent to the model; e.g., samples longer than `max_seq_len` should be dropped regardless of bucketing. What bucketing changes is how samples are stored; think of it as `std::vector::capacity` vs `size`. BTW, if the "capacity" exceeds `packing_max_batch_len`, that is fine.
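A sketch of that "capacity vs. size" view, assuming a simple padding collator; the pad token, the -100 label padding, and the batch layout are illustrative assumptions, not the PR's actual collate function:

```python
# Hedged sketch of bucketing applied only at the end of batch creation:
# each sample keeps its true length ("size"); only the padded storage
# ("capacity") is rounded up via bucket() on HPU. bucket() and
# is_torch_hpu_available() are assumed to come from this PR.
import torch


def collate_with_bucketing(batch, pad_token_id=0):
    max_len = max(len(item["input_ids"]) for item in batch)
    # Pad up to the bucketed length on HPU, to the true max length elsewhere.
    target_len = bucket(max_len) if is_torch_hpu_available() else max_len

    input_ids = torch.full((len(batch), target_len), pad_token_id, dtype=torch.long)
    labels = torch.full((len(batch), target_len), -100, dtype=torch.long)
    for i, item in enumerate(batch):
        n = len(item["input_ids"])
        input_ids[i, :n] = torch.tensor(item["input_ids"], dtype=torch.long)
        labels[i, :n] = torch.tensor(item["labels"], dtype=torch.long)
    return {"input_ids": input_ids, "labels": labels}
```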
This PR enables fine-tuning on HPU. It supports FSDP, torch.compile, and bucketing. HPU is autodetected, and the change doesn't affect existing functionality. The PR passed basic smoke testing on HPU; more testing is in progress.
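A sketch of the kind of autodetection described above, assuming `is_torch_hpu_available()` from this PR; the stand-in model and the "hpu_backend" torch.compile backend name (taken from Intel Gaudi's public documentation) are illustrative, and the PR's actual wiring may differ:

```python
# Hedged sketch: prefer HPU when the Habana plugin is available,
# otherwise fall back to CUDA/CPU, leaving non-HPU behavior unchanged.
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real fine-tuned model

if is_torch_hpu_available():
    device = torch.device("hpu")
    model = torch.compile(model.to(device), backend="hpu_backend")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.compile(model.to(device))
```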