[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
    self.assertTrue(
        isinstance(colwise_param, DTensor)
        and isinstance(
            colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
edited: without this PR, torch.chunk returns a bf16 tensor. FSDP2 is applied after TP, so it only sees Float8Linear(weight=DTensor(_local_tensor=Tensor))
with this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor
Can you explain where the bf16 came from?
correcting my wording to be accurate: without this PR, torch.chunk returns a plain Tensor (can be fp32 or bf16) instead of WeightWithDynamicFloat8CastTensor
    torch.ops.aten.as_strided.default,
    torch.ops.aten._to_copy.default,
    torch.ops.aten._pin_memory.default,
    torch.ops.aten.split.Tensor,
aten.split comes from torch.chunk, which is called from distribute_tensor during TP init
edited: @awgu curious if you still remember the reason for returning Tensor from torch.chunk instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I change torch.chunk to return WeightWithDynamicFloat8CastTensor?
@awgu curious if you still remember the reason to return bf16 from torch.chunk.
I thought that dtype and whether it is a WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether it is a WeightWithDynamicFloat8CastTensor or not)?
I think originally I only added the ops that I saw I needed. Adding aten.split and aten.clone seems okay to me.
whether is WeightWithDynamicFloat8CastTensor or not
exactly, whether it is a WeightWithDynamicFloat8CastTensor or not is the key. I edited my previous comments to say that right now torch.chunk returns Tensor
I think originally I only added the ops that I saw I needed
changing torch.chunk affects both TP and FSDP2. will double check FSDP2 after the change
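The allowlist behaviour discussed in this thread (ops on the list keep the weight wrapper, everything else falls back to a plain Tensor) can be illustrated with a toy model. This is a sketch in plain Python with no torch dependency; `Plain`, `CastWeight`, `split`, `clone`, and `dispatch` are hypothetical stand-ins, not the real `__torch_dispatch__` implementation in this PR.

```python
# Toy model only: plain Python, no torch. All names here are hypothetical.

class Plain:
    """Stands in for an ordinary Tensor."""
    def __init__(self, data):
        self.data = list(data)

class CastWeight(Plain):
    """Stands in for a weight wrapper such as WeightWithDynamicFloat8CastTensor."""

def split(t, n_chunks):
    """Toy version of aten.split as reached via torch.chunk; returns plain chunks."""
    size = -(-len(t.data) // n_chunks)  # ceiling division
    return [Plain(t.data[i:i + size]) for i in range(0, len(t.data), size)]

def clone(t):
    """Toy version of aten.clone; returns a plain copy."""
    return Plain(t.data)

def dispatch(op, preserve, t, *args):
    """Run op, re-wrapping outputs only if the input is wrapped and op is allowlisted."""
    out = op(t, *args)
    if not isinstance(t, CastWeight) or op not in preserve:
        return out  # wrapper type is lost here
    if isinstance(out, list):
        return [CastWeight(x.data) for x in out]
    return CastWeight(out.data)

w = CastWeight(range(8))
before = dispatch(split, {clone}, w, 2)         # split not on the allowlist
after = dispatch(split, {clone, split}, w, 2)   # split added to the allowlist
```

With `split` off the list, `before` holds `Plain` chunks, which mirrors FSDP2 later seeing only `DTensor(_local_tensor=Tensor)`; with `split` on the list, `after` holds `CastWeight` chunks.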
    elif isinstance(out, DTensor) and isinstance(
        out._local_tensor, Float8Tensor
    ):
        out._local_tensor._scale = scale
not sure about this change yet. just want to have something sketchy to discuss first
draft this PR for discussion, before having something landable
we see 2 problems in float8 all-gather FSDP2 + TP
weight, but expect all-reduce only for input
crux is how we dispatch torch.chunk, which is called from distribute_tensor for TP init
without this PR: torch.chunk returns Tensor. FSDP2 happens after TP, thus only sees Float8Linear(weight=DTensor(_local_tensor=Tensor))
with this PR: torch.chunk returns WeightWithDynamicFloat8CastTensor
profiler trace without this PR: AR (all-reduce) for input -> AG (all-gather) -> 4 ARs for wq,k,v,o -> 1 AR for input. The 4 ARs for wq,k,v,o should not happen if we precompute amax/scales for model.parameters() after opt.step()
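To make the all-reduce count in the profiler trace concrete, here is a toy simulation in plain Python (no torch, no real collectives). It assumes that precomputing means batching every weight's amax into one MAX reduction, that the float8 format is e4m3fn with a largest finite value of 448, and that the scale is `E4M3_MAX / max(amax, EPS)`. `CountingAllReduce`, `scales_per_weight`, `scales_precomputed`, and the `EPS` value are hypothetical names and choices made for illustration.

```python
# Toy simulation in plain Python; no torch or real collectives involved.
E4M3_MAX = 448.0  # largest finite float8 e4m3fn value
EPS = 1e-12       # assumed clamp to avoid division by zero

class CountingAllReduce:
    """Simulated elementwise MAX all-reduce across ranks that counts its calls."""
    def __init__(self):
        self.calls = 0

    def __call__(self, per_rank_vectors):
        self.calls += 1
        return [max(col) for col in zip(*per_rank_vectors)]

def local_amax(shard):
    """Largest absolute value in one rank's shard of a weight."""
    return max(abs(x) for x in shard)

def scales_per_weight(shards_by_rank, all_reduce):
    """One all-reduce per weight, like the 4 ARs for wq,k,v,o in the trace."""
    scales = {}
    for name in shards_by_rank[0]:
        amax = all_reduce([[local_amax(rank[name])] for rank in shards_by_rank])[0]
        scales[name] = E4M3_MAX / max(amax, EPS)
    return scales

def scales_precomputed(shards_by_rank, all_reduce):
    """All amaxes packed into one vector, reduced with a single all-reduce."""
    names = list(shards_by_rank[0])
    vectors = [[local_amax(rank[n]) for n in names] for rank in shards_by_rank]
    amaxes = all_reduce(vectors)
    return {n: E4M3_MAX / max(a, EPS) for n, a in zip(names, amaxes)}

# Two simulated ranks, each holding a shard of four attention weights.
ranks = [
    {"wq": [0.5, -2.0], "wk": [1.0, 0.25], "wv": [-4.0, 3.0], "wo": [0.125, 0.5]},
    {"wq": [1.5, -0.5], "wk": [-8.0, 2.0], "wv": [0.5, 1.0], "wo": [-0.25, 1.0]},
]
eager, batched = CountingAllReduce(), CountingAllReduce()
per_weight = scales_per_weight(ranks, eager)
precomputed = scales_precomputed(ranks, batched)
```

Both paths produce the same scales, but `eager.calls` is 4 while `batched.calls` is 1, which is the difference the trace is pointing at.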