-
Notifications
You must be signed in to change notification settings - Fork 298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements and fixes to gradient accumulation #993
base: main
Are you sure you want to change the base?
Conversation
axlearn/experiments/text/gpt/fuji.py
Outdated
# Note: the batch axes are different here than in | ||
# `cfg.batch_axis_names`, | ||
# as we partition sequence dim over `seq`. | ||
(None, 1): PartitionSpec(("data", "expert", "fsdp")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering, if we have a default input partition with axis=0 on ("data", "expert", "fsdp") and axis=1 on "seq", do we still need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick review.
(None, 1) is for the target_num_bytes
and (None, 2) is for the input_ids
and target_labels
, so we need both. Together they will work for most cases, but for the outliers where a specific sharding is required the ability to change sharding for the minibatches will be good to have.
Let me know if this answers your question.
), | ||
input_partition_spec(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me, it seems rather a hack than a proper solution, that is, we want to have a different input_partition_spec()
than the default one, then we need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed the default case, added it.
I think the below partition spec is good as a default, but the ability to change PartitionSpec might be good to have, what do you think?
(None, 1): PartitionSpec(("data", "expert", "fsdp")),
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
9b0f9a3
to
32a78ea
Compare
- Fix to with_minibatch_steps decorator to generate correct primal outputs shapes. - Improved with_minibatch_steps to take a minibatch_partitioner that contraints the input batch to the same PartitionSpec as Input Partitioner.
32a78ea
to
186e082
Compare
@@ -57,39 +59,38 @@ def _make_scan_minibatch_inputs( | |||
param_noise_key: Tensor, | |||
minibatch_size: int, | |||
minibatch_index: int, | |||
minibatch_partitioner: Optional[InputPartitionFn], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Echoing Kelvin's comment, could you explain concretely why we need this functionality? If it's just something that might be useful, maybe we can wait until we are certain that we will need it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the case where gradient accumulation is not enabled, the inputs to the graph are sharded as per the policy in input_partitioner. This ensures the batch dimension is sharded on data, expert and fsdp axes while sequence dimension is replicated on model axis.
Gradient accumulation wraps the train steps in a scan loop, while the input_partitioner shards the input batch to correctly at first. In the gradient accumulation wrapper the input batches are resharded/overridden by the function _make_scan_minibatch_inputs and sharded along all axes available which is probably unexpected and inefficient. Minibatches should follow the same PartitionSpec as input_batches.
The addition of the minibatch_partitioner allows the minibatches to use the same sharding/PartitionSpec as input_partitioner
provides in the input batches in the case gradient accumulation is not used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we just preserve the sharding the input already has, would that also address the concern about the input sharding being changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah preserving sharding of the input and not having a sharding_constraint
for minibatches would address the concern as well.
# Default partitioner for minibatches. | ||
if not minibatch_partitioner: | ||
minibatch_partitioner = partition_by_path_rank( | ||
path_rank_to_partition={ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we default this to the same sharding the input is already using along all non-batch axes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just confirming if I read it correctly, we want to default to input_partition_specs from utils.py
like before, and not what the input_partitioner sets.
Or the ask is to use the partition_by_path_rank
to replicate what input_partition_specs
was doing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not exactly. I was envisioning that for all axes other than axis 0, we default to whatever sharding the input already has. For axis 0, ideally we could also keep whatever sharding the input already has too, although I'm not sure that would work with logical batching.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For axis 0, ideally we could also keep whatever sharding the input already has too, although I'm not sure that would work with logical batching
I think preserving the sharding of the input would be perfect, logical batching already inserts the correct sharding constraint after squeezing out the padded batches
Removed additional sharding constraints from gradient accumulation decorator, minibatches now should use the sharding spec created by the |
8c16718
to
306f089
Compare
), | ||
input_partition_spec(), | ||
inputs["input_batch"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step).
Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?)
Maybe we should reshard the batch axis only?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can completely avoid the batch reshards using a reshape + transpose. I added it to the, PR let me know if it addresses your concerns.
Using the same example as @apghml:
Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step).
Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?)
Rather than using first 10 batches available in the global batch array for the first iteration, we construct the minibatch using the first batch from every device that is minibatch 0 =>[0, 10, 20 ....], minibatch 1 => [1, 11, 21, ...]. This is achieved using the reshape and transpose.
Essentially the logic here is to ensure each device uses local batches avoiding extra reshards.
This also scales well across multiple nodes as each node only runs a local reshape + transpose, also higher per device BS is also supported.
This should addresses the concerns around input batch reshards, let me know if there are still more concerns.
+1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size?
@kelvin-zou I can't think of a way to get size of a specific axis at runtime, but I do believe JAX should be able to give an informative error if the batch size % batch axis size != 0
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. Can you add a test that fails without this fix?
@@ -172,12 +167,26 @@ def fwd_helper( | |||
otherwise None. | |||
""" | |||
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps) | |||
|
|||
# Create a sample minibatch for the carry buffer creation below |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain in more detail why this is needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw broadcasting errors coming from the scan body, (example below), JAX complained that the carry buffer shape and the output of minibatch step are incompatible.
PS below error where acc=4 and full batch size is 32
TypeError: add got incompatible shapes for broadcasting: (32, 4096, 3072), (8, 4096, 3072).
The carry buffer initialization uses the full batch while creating the buffer, which does not match with the output of minibatch step since it would use the shapes of minibatch.
The simple fix for this is to use a minibatch sample for creating carry buffer ensuring it's shapes are same as the minibatch step.
Let me know if I missed something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we know why this issue wasn't causing errors before?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unit test uses a toy model which does not have any metric/output that relies on batch size which is why it does not catch this issue. I dug a bit deeper and found that for fuji modelsoutput_collection/module_outputs/decoder/transformer/layer3/output
carries batch dimension in output - ref below.
path (GetAttrKey(name='output_collection'), GetAttrKey(name='module_outputs'), DictKey(key='decoder'), DictKey(key='transformer'), DictKey(key='layer3'), DictKey(key='output')) shape (32, 4096, 3072)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me, will approve once addressed @apghml 's comments.
@@ -172,12 +167,26 @@ def fwd_helper( | |||
otherwise None. | |||
""" | |||
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps) | |||
|
|||
# Create a sample minibatch for the carry buffer creation below |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
), | ||
input_partition_spec(), | ||
inputs["input_batch"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size?
"""Helper function that adds a minibatch dimension while evenly dividing | ||
batches across gradient accumulation iterations. | ||
|
||
Input dimension is [GBS, seq], this first reshaped to [MBS, steps, seq], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace the acronyms with full names?
), | ||
input_partition_spec(), | ||
inputs["input_batch"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. Can you add a test that fails without this fix?
# Set up transpose to swap the first two dimensions. | ||
dims = list(range(x.ndim)) | ||
dims[0], dims[1] = dims[1], dims[0] | ||
return x.transpose(dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we replace these three lines with one line if we use jnp.moveaxis
?
with_minibatch_steps
decorator to generate correct primal outputs shapes.with_minibatch_steps
to take aminibatch_partitioner
that constraints the accumulation minibatch to the same PartitionSpec asinput_partitioner
.Misc: